-
Notifications
You must be signed in to change notification settings - Fork 191
Expand file tree
/
Copy pathwrapper.py
More file actions
652 lines (584 loc) · 28.2 KB
/
Copy pathwrapper.py
File metadata and controls
652 lines (584 loc) · 28.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
"""Wrapper for cuDNN frontend to improve user experience.
This wrapper provides a more user-friendly interface for cuDNN frontend.
It allows users to create a graph, add operations to the graph, and then
compile the graph to a cuDNN plan. This wrapper is designed to avoid
boilerplate code.
Key Features:
- Automatic graph validation and compilation
- Simplified tensor management
- Support for both named and positional tensor inputs
- Automatic workspace management
- PyTorch integration with DLPack support
Example:
>>> x = torch.randn(8, 56, 56, 64, device=device, dtype=torch.float16).permute(0, 3, 1, 2)
>>> w = torch.randn(32, 3, 3, 64, device=device, dtype=torch.float16).permute(0, 3, 1, 2)
>>> with Graph() as graph:
... y = graph.conv_fprop(
... image=x, weight=w,
... padding=[1,1], stride=[1,1], dilation=[1,1],
... compute_data_type=data_type.FLOAT,
... name="conv2d",
... )
... y.set_output(True).set_data_type(data_type.HALF)
... # Graph is automatically validated and compiled on exit
>>> graph.set_io_tuples(["conv2d::image", "conv2d::weight"], ["conv2d::Y"])
>>> # Execute the graph
>>> output = graph(x, w)
"""
from collections import OrderedDict
import atexit
import itertools
import inspect
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
import cudnn
import cudnn.datatypes
from cudnn import data_type, heur_mode
try:
import torch
except ImportError:
torch = None
# Public names re-exported by this module; `data_type`, `heur_mode` and `cudnn`
# are re-exported so callers can import them from the wrapper directly.
__all__ = ["Graph", "data_type", "heur_mode", "cudnn"]
# typedefs for readability
# A cuDNN handle is carried around as a plain integer in this module.
CudnnHandle = int
# Lazily created by get_default_handle() and released by destroy_default_handle().
_default_cudnn_handle = None
# Configure logging
# Module-level logger; this module attaches no handlers or levels itself.
logger = logging.getLogger(__name__)
def _graph_tensor(graph: cudnn.pygraph, tensor: "torch.Tensor") -> cudnn.tensor:
    """Register a DLPack-compatible tensor with a cuDNN graph.

    Args:
        graph: The cuDNN graph object that will own the new graph tensor.
        tensor: The source tensor whose layout is mirrored via
            ``graph.tensor_like``.

    Returns:
        The cuDNN tensor object that ``graph.tensor_like`` produced.

    Note:
        A tensor exposing a truthy ``requires_grad`` attribute is detached
        first, so the graph tensor is built from a gradient-free view.
    """
    tracks_grad = hasattr(tensor, "requires_grad") and tensor.requires_grad
    # PyTorch tensors that track gradients must be detached before export
    source = tensor.detach() if tracks_grad else tensor
    return graph.tensor_like(source)
def _find_tensor(
    tensor: Union[str, cudnn.tensor, "torch.Tensor"],
    tensor_map: Dict[str, cudnn.tensor],
    dlpack_map: Dict[int, cudnn.tensor],
) -> str:
    """Resolve a user-supplied tensor reference to its key in ``tensor_map``.

    The reference may be given as:
        - a string: either a key of ``tensor_map`` (``node::input_name``) or
          the name assigned to a cuDNN tensor (``get_name()``)
        - an integer: the uid of a cuDNN tensor (``get_uid()``)
        - a cuDNN tensor object (matched by identity)
        - a DLPack-compatible object whose ``id()`` is a key of ``dlpack_map``

    Args:
        tensor: The reference to resolve.
        tensor_map: Mapping of canonical names to cuDNN tensor objects.
        dlpack_map: Mapping of ``id()`` of DLPack tensors to cuDNN tensors.

    Returns:
        The key of ``tensor_map`` that the reference resolves to.

    Raises:
        ValueError: If the reference matches no entry of ``tensor_map``.
    """
    matches = None  # predicate applied to each cuDNN tensor in tensor_map

    if isinstance(tensor, str):
        if tensor in tensor_map:
            return tensor  # already the canonical "node::input_name"

        def matches(candidate):
            # fall back to the name assigned to the tensor
            return candidate.get_name() == tensor

    elif isinstance(tensor, int):

        def matches(candidate):
            return candidate.get_uid() == tensor

    elif isinstance(tensor, cudnn.tensor):

        def matches(candidate):
            return tensor is candidate

    elif hasattr(tensor, "__dlpack__") and isinstance(dlpack_map, dict) and id(tensor) in dlpack_map:
        graph_tensor = dlpack_map[id(tensor)]

        def matches(candidate):
            return candidate == graph_tensor

    if matches is not None:
        for key, candidate in tensor_map.items():
            if matches(candidate):
                return key
    raise ValueError("Input not found in tensor map")
def _extract_tensor(name: str, tensor: cudnn.tensor, arg_dict: dict) -> Optional["torch.Tensor"]:
    """Look up the user tensor in ``arg_dict`` that corresponds to a graph tensor.

    Keys are tried in this order: the canonical ``name``, the cuDNN tensor
    object itself, the tensor's assigned name, and finally the tensor's uid.

    Args:
        name: Canonical name of the graph tensor.
        tensor: The cuDNN tensor being looked up.
        arg_dict: User-provided mapping to search.

    Returns:
        The first matching value of ``arg_dict``, or ``None`` if no key matches.
    """
    # direct keys: canonical name first, then the cuDNN tensor object
    for key in (name, tensor):
        if key in arg_dict:
            return arg_dict[key]
    # derived keys: assigned name first, then the tensor uid
    for derive_key in (tensor.get_name, tensor.get_uid):
        try:
            return arg_dict[derive_key()]
        except KeyError:
            continue
    return None  # not found
def _tensor_like(cudnn_tensor: cudnn.tensor, tensor_type: str = "pyt") -> "torch.Tensor":
    """Allocate an uninitialized tensor shaped like a cuDNN graph tensor.

    Args:
        cudnn_tensor: The cuDNN tensor supplying dims, strides and data type.
        tensor_type: Framework of the tensor to allocate; only ``"pyt"``
            (PyTorch) is accepted.

    Returns:
        A tensor created by ``torch.empty_strided`` on device ``"cuda"``.

    Raises:
        NotImplementedError: If ``tensor_type`` is not ``"pyt"``.
        RuntimeError: If cuDNN reports that PyTorch is unavailable.
        TypeError: If the cuDNN data type has no PyTorch equivalent.
    """
    if tensor_type != "pyt":
        raise NotImplementedError("Only PyTorch tensor is supported for now")
    if not cudnn.datatypes.is_torch_available():
        raise RuntimeError("PyTorch is not available")

    torch_dtype = cudnn.datatypes._cudnn_to_torch_data_type(cudnn_tensor.get_data_type())
    if torch_dtype is None:
        raise TypeError(f"cuDNN uses an unsupported data type in PyTorch: {cudnn_tensor.get_data_type()}")

    dims = cudnn_tensor.get_dim()
    strides = cudnn_tensor.get_stride()
    return torch.empty_strided(dims, strides, device="cuda", dtype=torch_dtype)
def get_default_handle(stream: Optional["torch.cuda.Stream"] = None) -> CudnnHandle:
    """Return the module-wide cuDNN handle, bound to the requested CUDA stream.

    The handle is created on first use and reused afterwards. Every call
    re-binds the handle to a stream.

    Args:
        stream: Stream to bind the handle to. May be a ``torch.cuda.Stream``
            object (its ``cuda_stream`` pointer is used) or a raw stream
            pointer. When ``None``, PyTorch's current stream is used.

    Returns:
        The shared cuDNN handle.

    Raises:
        RuntimeError: If PyTorch could not be imported.
    """
    global _default_cudnn_handle
    if torch is None:
        raise RuntimeError("PyTorch is not available")
    if _default_cudnn_handle is None:
        _default_cudnn_handle = cudnn.create_handle()
    if stream is None:
        stream = torch.cuda.current_stream().cuda_stream
    else:
        # The default path hands cuDNN the raw stream pointer, so do the same
        # for a caller-supplied torch.cuda.Stream; raw pointers pass through.
        stream = getattr(stream, "cuda_stream", stream)
    cudnn.set_stream(handle=_default_cudnn_handle, stream=stream)
    return _default_cudnn_handle
def destroy_default_handle():
    """Destroy the module-wide cuDNN handle, if one was created.

    The global is cleared before the handle is destroyed, so calling this
    function again (for example manually and then via the ``atexit`` hook)
    does not destroy the same handle twice. A later ``get_default_handle()``
    call creates a fresh handle.
    """
    global _default_cudnn_handle
    if _default_cudnn_handle is not None:
        handle, _default_cudnn_handle = _default_cudnn_handle, None
        cudnn.destroy_handle(handle)


# Release the shared handle at interpreter shutdown.
atexit.register(destroy_default_handle)
class Graph:
    """Wrapper object for cuDNN computation graph

    This class simplifies the process of creating, compiling, and executing
    cuDNN computation graphs. It handles common boilerplate code and provides
    a more Pythonic interface to the cuDNN frontend API.

    Key features:
        - Automatic graph validation and compilation
        - Simplified tensor management with PyTorch integration
        - Support for both named and positional tensor inputs
        - Automatic workspace management

    Note:
        The graph is automatically validated and compiled when exiting the
        context manager. Any errors in graph construction will be raised
        at that point.
    """

    __handle: Optional[CudnnHandle] = None  # holding the cudnn handle pointer

    # Graph methods forwarded untouched (no tensor interception needed).
    _PASS_THROUGH_METHODS = frozenset(
        {
            "get_workspace_size",
            "get_workspace_size_plan_at_index",
            "serialize",
            "deserialize",
            "query_tensor_attributes_of_uid",
        }
    )

    # Graph methods that must not be called through the wrapper because the
    # wrapper drives validation, plan building and execution itself.
    # TODO should allow user select execution plan
    _BLOCKED_METHODS = frozenset(
        {
            "build",
            "build_operation_graph",
            "build_plan_at_index",
            "build_plans",
            "check_support",
            "create_execution_plan",
            "create_execution_plans",
            "deselect_behavior_notes",
            "deselect_engines",
            "deselect_numeric_notes",
            "deselect_workspace_greater_than",
            "execute",
            "execute_plan_at_index",
            "get_behavior_notes",
            "get_behavior_notes_for_plan_at_index",
            "get_engine_count",
            "get_execution_plan_count",
            "get_knobs_for_engine",
            "get_plan_name_at_index",
            "key",
            "populate_cuda_graph",
            "select_behavior_notes",
            "select_numeric_notes",
            "update_cuda_graph",
            "validate",
        }
    )

    def __init__(
        self,
        *,
        handle: Optional[CudnnHandle] = None,
        inputs: Optional[List[Union[str, "torch.Tensor", cudnn.tensor]]] = None,
        outputs: Optional[List[Union[str, "torch.Tensor", cudnn.tensor]]] = None,
        heuristics: Optional[List[heur_mode]] = None,
        workspace_alloc: bool = True,
        **kwargs,
    ) -> None:
        """Configure the wrapper; the cuDNN graph itself is created on ``__enter__``.

        Args:
            handle: cuDNN handle to use, or ``"auto"`` to use the module-wide
                default handle at execution time.
            inputs: Optional input order, applied via ``set_io_tuples()`` once
                the graph has been compiled.
            outputs: Optional output order, applied together with ``inputs``.
            heuristics: Heuristic modes passed to ``create_execution_plans``;
                defaults to ``[heur_mode.A, heur_mode.FALLBACK]``.
            workspace_alloc: When False, no workspace is allocated on exit and
                the caller must pass ``workspace=`` on every execution.
            **kwargs: Forwarded to ``cudnn.pygraph``. PyTorch dtypes given for
                ``io_data_type``, ``intermediate_data_type`` and
                ``compute_data_type`` are converted to cuDNN data types.

        Raises:
            RuntimeError: If PyTorch is unavailable or cuDNN is older than 9.12.0.
            ValueError: If ``inputs``, ``outputs`` or ``heuristics`` is given
                but is not a list or tuple.
        """
        if torch is None:
            raise RuntimeError("PyTorch is not available")
        if inputs and not isinstance(inputs, (list, tuple)):
            raise ValueError("inputs must be a list or tuple")
        if outputs and not isinstance(outputs, (list, tuple)):
            raise ValueError("outputs must be a list or tuple")
        if heuristics and not isinstance(heuristics, (list, tuple)):
            raise ValueError("heuristics must be a list or tuple")
        if cudnn.backend_version() < 91200:
            raise RuntimeError("cuDNN version 9.12.0 or higher is required")
        self.__kwargs = kwargs
        self.__graph = None  # to hold the cudnn.pygraph object
        self.__tensor_map = {}  # obj id of dlpack tensor -> cudnn tensor
        self.__tensor_in = OrderedDict()  # canonical node::argname -> cudnn tensors used as the input
        self.__tensor_out = OrderedDict()  # canonical node::outname -> cudnn tensors produced by the node
        self.__tensor_unknown = []  # list of cuDNN tensors created by user directly
        self.__node_count = {}  # function name of graph node -> number of times used
        self.__node_names = set()  # set of assigned names of graph nodes, to check name collision
        self.__input_tuples = None  # tuple of input tensors, if set by set_io_tuples
        self.__output_tuples = None  # tuple of output tensors, if set by set_io_tuples
        self.__inputs = inputs or []  # hold the list of inputs, to be used by set_io_tuples() implicitly
        self.__outputs = outputs or []  # hold the list of outputs, to be used by set_io_tuples() implicitly
        self.__heuristics = heuristics or [heur_mode.A, heur_mode.FALLBACK]
        # None: allocate the workspace on exit; False: the user opted out.
        # Always assigned so that the check in __exit__ does not depend on
        # hasattr() with a name-mangled attribute.
        self.__workspace = None if workspace_alloc else False
        if handle:
            self.__handle = handle
        # silently replace the PyTorch dtype into cuDNN dtype
        for key in ["io_data_type", "intermediate_data_type", "compute_data_type"]:
            if key in kwargs:
                kwargs[key] = cudnn.datatypes._torch_to_cudnn_data_type(kwargs[key]) or kwargs[key]

    def __del__(self):
        pass

    def __enter__(self):
        """Create the underlying ``cudnn.pygraph`` and start graph construction.

        Raises:
            RuntimeError: If a graph was already created by this wrapper.
        """
        if self.__graph is not None:
            raise RuntimeError("Graph already created")
        self.__graph = cudnn.pygraph(
            # Pass handle only if self.__handle is an existing handle
            **({"handle": self.__handle} if self.__handle not in ["auto", None] else {}),
            **self.__kwargs,
        )
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Exit the context manager, validating and compiling the graph.

        This method performs the following steps:
            1. Validates the graph structure
            2. Builds the operation graph
            3. Creates execution plans
            4. Checks hardware support
            5. Builds the final plans
            6. Allocates workspace memory (unless ``workspace_alloc=False``)
            7. Applies the I/O order given to ``__init__`` or to an early
               ``set_io_tuples()`` call

        If the ``with`` body raised, the graph is discarded and the exception
        propagates unchanged. Errors raised by the cuDNN calls above propagate
        to the caller.
        """
        # if there is an exception, clean up and propagate
        if exc_type is not None:
            logger.error("Exception during graph construction: %s", exc_value)
            self.__graph = None
            return False  # falsy return lets the original exception propagate
        # prepare the graph and build plans: Each should return None or raise exception
        self.__graph.validate()
        self.__graph.build_operation_graph()
        self.__graph.create_execution_plans(self.__heuristics)
        # TODO: let user select_behavior_notes() and select_numeric_notes() here
        self.__graph.check_support()
        self.__graph.build_plans()
        # Set up workspace if not forbidden by user, then set up I/O tensor orders
        if self.__workspace is None:
            self.__workspace = torch.empty(
                self.__graph.get_workspace_size(),
                device="cuda",
                dtype=torch.uint8,
            )
        if self.__inputs or self.__outputs:
            self.set_io_tuples(self.__inputs, self.__outputs)
        del self.__inputs, self.__outputs
        logger.debug("Inputs: %s", self.__tensor_in)
        logger.debug("Outputs: %s", self.__tensor_out)
        logger.debug("Node count: %s", self.__node_count)
        # No exception is pending here, so the with-statement ignores this value.
        return self.__graph

    def __to_graph_tensor(self, obj):
        """Map a DLPack object to its graph tensor, creating it on first use; pass other objects through."""
        if not hasattr(obj, "__dlpack__"):
            return obj
        obj_id = id(obj)
        if obj_id not in self.__tensor_map:
            self.__tensor_map[obj_id] = _graph_tensor(self.__graph, obj)
        return self.__tensor_map[obj_id]

    def __bind_argument(self, node_name, key, obj):
        """Convert one node argument to graph tensors and record every cuDNN tensor as a node input."""
        if isinstance(obj, tuple):
            # use case: dropout in sdpa is a tuple of (float, seed tensor, offset tensor)
            converted = []
            for j, elem in enumerate(obj):
                elem = self.__to_graph_tensor(elem)
                if isinstance(elem, cudnn.tensor):
                    self.__tensor_in[f"{node_name}::{key}::{j}"] = elem
                converted.append(elem)
            return tuple(converted)
        obj = self.__to_graph_tensor(obj)
        if isinstance(obj, cudnn.tensor):
            self.__tensor_in[f"{node_name}::{key}"] = obj
        return obj

    def __getattr__(self, name: str) -> Any:
        """Intercept method calls to the graph object, usually during graph
        construction, and handle them appropriately.

        Returns:
            - the graph attribute itself for pass-through methods and non-methods
            - a capturing wrapper for ``tensor`` / ``tensor_like``
            - for any other method, a wrapper that converts DLPack arguments
              into graph tensors, assigns a node name, and records the node's
              input and output tensors

        Raises:
            AttributeError: If the graph does not exist yet or has no such attribute.
            RuntimeError: If ``name`` is a method the wrapper reserves for itself.
        """
        # A missing private attribute (e.g. __init__ raised early) must not be
        # delegated: reading self.__graph below would recurse forever.
        if name.startswith("_Graph__"):
            raise AttributeError(name)
        graph = self.__graph
        if graph is None:
            raise AttributeError(f"{name!r} is not available: the graph has not been created yet")
        attr = getattr(graph, name)
        # calling tensor_like is unnecessary, just pass through
        if name in self._PASS_THROUGH_METHODS:
            return attr
        # some methods are blocked and should not be called via wrapper
        if name in self._BLOCKED_METHODS:
            raise RuntimeError(f"Calling {name} via wrapper is not allowed")
        # non-methods: pass through. Probably not used but be safe
        if not inspect.ismethod(attr):
            return attr

        # tensor creation methods: capture the output
        def tensor_capture(*args, **kwargs):
            output = attr(*args, **kwargs)
            self.__tensor_unknown.append(output)
            return output

        if name in ["tensor", "tensor_like"]:
            return tensor_capture

        # other methods: wrap the method to intercept the arguments and return values
        def wrapper(*args, **kwargs):
            # determine the name of the graph node, the node may carry name attribute
            self.__node_count[name] = self.__node_count.get(name, 0) + 1
            if "name" in kwargs:
                node_name = kwargs["name"]
            else:
                node_name = f"{name}.{self.__node_count[name]-1}"
                kwargs["name"] = node_name
            if node_name in self.__node_names:
                raise ValueError(f"Node name {node_name} already used")
            self.__node_names.add(node_name)
            # process positional, then keyword arguments for dlpack tensors
            args = [self.__bind_argument(node_name, i, obj) for i, obj in enumerate(args)]
            for key in list(kwargs):
                kwargs[key] = self.__bind_argument(node_name, key, kwargs[key])
            # capturing node output
            output = attr(*args, **kwargs)
            if isinstance(output, cudnn.tensor):
                output_list = [output]
            elif isinstance(output, (list, tuple)):
                output_list = output
            else:
                output_list = []  # nothing to record for non-tensor returns
            for i, obj in enumerate(output_list):
                if isinstance(obj, cudnn.tensor):
                    if hasattr(obj, "get_name") and obj.get_name():
                        tensor_name = obj.get_name()
                    else:
                        tensor_name = f"{node_name}::{i}"
                    self.__tensor_out[tensor_name] = obj
            return output

        return wrapper

    def __call__(self, *args, **kwargs):
        """Execute the graph.

        Call with a single ``dict`` to execute by tensor name (see
        ``__call_with_tensor_dict``), or with positional tensors in the order
        set by ``set_io_tuples()``.

        Raises:
            RuntimeError: If the graph was not created or not yet compiled.
            ValueError: If positional execution is requested before
                ``set_io_tuples()``, or the number of arguments is wrong.
        """
        if self.__graph is None:
            raise RuntimeError("Graph not created")
        if not self.__graph.get_execution_plan_count():
            raise RuntimeError("You should not invoke the graph before the context exits")
        if len(args) == 1 and isinstance(args[0], dict):
            return self.__call_with_tensor_dict(args[0], **kwargs)
        # Explicit None check: also covers a call with no arguments, which
        # would otherwise fail on len(None) below.
        if self.__input_tuples is None:
            raise ValueError("You should not invoke the graph with positional arguments before running set_io_tuples()")
        if len(args) != len(self.__input_tuples):
            raise ValueError(f"Number of arguments ({len(args)}) does not match number of inputs ({len(self.__input_tuples)})")
        return self.__call_with_positional_args(*args, **kwargs)

    def __execution_kwargs(self, kwargs):
        """Return a copy of ``kwargs`` with ``handle`` and ``workspace`` filled in from the wrapper's state."""
        kwargs = dict(kwargs)  # shallow copy
        if "handle" not in kwargs:
            if self.__handle == "auto":
                kwargs["handle"] = get_default_handle()
            elif self.__handle is not None:
                kwargs["handle"] = self.__handle
            else:
                raise RuntimeError("Need to specify cudnn handle to execute graph")
        if "workspace" not in kwargs:
            if self.__workspace is False or self.__workspace is None:
                raise RuntimeError("Need to specify workspace to execute graph")
            kwargs["workspace"] = self.__workspace
        return kwargs

    def __call_with_positional_args(self, *args, **kwargs) -> Union["torch.Tensor", Tuple["torch.Tensor", ...]]:
        """Execute the graph with positional arguments.

        Args:
            *args: Input tensors, in the order of ``self.__input_tuples``
            **kwargs: Additional keyword arguments to pass to the graph execution

        Returns:
            The single output tensor when there is exactly one output,
            otherwise a list of the output tensors in the order of
            ``self.__output_tuples``. Output tensors are freshly allocated
            on every call.

        Note:
            This method is called by __call__() when the graph is executed with positional arguments.
            It is not intended to be called directly by the user.
        """
        # prepare the variant pack: uid -> user tensor for every input and output
        variant_pack = {}
        for cudnn_tensor, user_tensor in zip(self.__input_tuples, args):
            variant_pack[cudnn_tensor.get_uid()] = user_tensor
        output_tuple = [_tensor_like(cudnn_tensor, "pyt") for cudnn_tensor in self.__output_tuples]
        for cudnn_tensor, user_tensor in zip(self.__output_tuples, output_tuple):
            variant_pack[cudnn_tensor.get_uid()] = user_tensor
        # execute the graph
        self.__graph.execute(variant_pack, **self.__execution_kwargs(kwargs))
        # return the output as a single tensor or a sequence
        if len(output_tuple) == 1:
            return output_tuple[0]
        return output_tuple

    def __call_with_tensor_dict(
        self,
        tensor_dict: Dict[str, "torch.Tensor"],
        **kwargs,
    ) -> Dict[str, "torch.Tensor"]:
        """Execute the graph with a dictionary of tensors.

        Args:
            tensor_dict: Dictionary of tensor names to tensors. Keys may be
                anything ``_extract_tensor`` accepts. Output tensors that are
                not supplied are allocated and added to this dictionary
                in place.
            **kwargs: Additional keyword arguments to pass to the graph execution

        Returns:
            The same ``tensor_dict`` object, including any outputs that were added.

        Raises:
            RuntimeError: If a non-virtual input tensor in the graph is not found in
                `tensor_dict`, or the tensor in `tensor_dict` is not a dlpack tensor
        """
        # prepare the variant pack:
        # all non-virtual tensors in __tensor_in and __tensor_out should be filled
        variant_pack = {}
        # cudnn tensor -> (name, produced_by_node). Outputs are visited last, so
        # a tensor that is both consumed and produced is recorded as an output.
        # The origin is tracked explicitly because positional input keys
        # ("node::0") can collide with unnamed output keys ("node::0").
        missing_tensors = {}
        for tensors, produced_by_node in ((self.__tensor_in, False), (self.__tensor_out, True)):
            for name, tensor in tensors.items():
                if tensor.get_uid() in variant_pack or tensor.get_is_virtual():
                    continue  # already filled or not needed
                user_tensor = _extract_tensor(name, tensor, tensor_dict)
                if user_tensor is None:
                    missing_tensors[tensor] = (name, produced_by_node)  # overwriting existing entries
                    continue
                if not hasattr(user_tensor, "__dlpack__"):
                    raise RuntimeError(f"Tensor {name} is not provided as a dlpack tensor")
                variant_pack[tensor.get_uid()] = user_tensor
        # check if all non-virtual tensors are filled
        missing_inputs = []
        missing_outputs = []
        for tensor, (name, produced_by_node) in missing_tensors.items():
            if tensor.get_uid() in variant_pack:
                continue  # already filled
            if produced_by_node:
                # output tensor not specified, should be created automatically
                variant_pack[tensor.get_uid()] = tensor_dict[name] = _tensor_like(tensor, "pyt")
                missing_outputs.append(name)
            else:
                # input tensor not specified, flag it as missing
                missing_inputs.append(name)
        if missing_inputs:
            raise RuntimeError(f"Non-virtual input tensors not found in variant pack: {missing_inputs}")
        if missing_outputs:
            logger.debug("Added output tensors: %s", missing_outputs)
        # execute the graph
        self.__graph.execute(variant_pack, **self.__execution_kwargs(kwargs))
        return tensor_dict  # by this time, the output tensors are updated

    def __resolve_io_tensors(self, items, tensor_map, label, seen_ids):
        """Resolve user I/O references to cuDNN tensors, rejecting unknown and duplicate entries."""
        resolved = []
        for i, item in enumerate(items):
            if any(item is known for known in self.__tensor_unknown):
                tensor = shown = item  # user-created cuDNN tensor object
            else:
                # Only the lookup is guarded, so the duplicate error below is
                # not re-reported as "not found".
                try:
                    shown = _find_tensor(item, tensor_map, self.__tensor_map)
                except ValueError:
                    raise ValueError(f"{label} at index {i} ({item}) not found in tensor map") from None
                tensor = tensor_map[shown]
            if id(tensor) in seen_ids:
                raise ValueError(f"{label} at index {i} ({shown}) is a duplicate")
            seen_ids.add(id(tensor))
            resolved.append(tensor)
        return resolved

    def set_io_tuples(
        self,
        inputs: List[Union[str, "torch.Tensor", cudnn.tensor]],
        outputs: List[Union[str, "torch.Tensor", cudnn.tensor]],
    ) -> None:
        """Set order of input and output tensors to allow graph to be executed with positional arguments.

        If the graph has not been compiled yet, the lists are stored and
        applied automatically when the context exits.

        Args:
            inputs: List of input tensors or names
            outputs: List of output tensors or names

        Raises:
            ValueError: If inputs or outputs are not lists or tuples, an entry
                cannot be found or is listed twice, an input is a virtual
                tensor, or a non-virtual tensor of the graph is not covered
                by the given inputs and outputs.
        """
        if not isinstance(inputs, (list, tuple)):
            raise ValueError("inputs must be a list or tuple")
        if not isinstance(outputs, (list, tuple)):
            raise ValueError("outputs must be a list or tuple")
        if self.__graph is None or not self.__graph.get_execution_plan_count():
            # not compiled yet: remember the request, __exit__ applies it
            self.__inputs = inputs
            self.__outputs = outputs
            return
        # self.__tensor_out and self.__tensor_in are dict of str -> cudnn tensor
        # all non-virtual tensors should be either input or output
        tensors_found = set()  # id() of every tensor already listed
        input_tensors = self.__resolve_io_tensors(inputs, self.__tensor_in, "Input", tensors_found)
        output_tensors = self.__resolve_io_tensors(outputs, self.__tensor_out, "Output", tensors_found)
        # Verify that all input tensors are non-virtual
        for i, tensor in enumerate(input_tensors):
            if tensor.get_is_virtual():
                raise ValueError(f"Input at index {i} is a virtual tensor")
        # Verify that all non-virtual tensors are covered by input or output
        output_ids = {id(tensor) for tensor in output_tensors}
        for name, tensor in self.__tensor_out.items():
            if not tensor.get_is_virtual() and id(tensor) not in output_ids:
                raise ValueError(f"Node output {name} is a non-virtual tensor but not specified as output")
        for name, tensor in self.__tensor_in.items():
            if not tensor.get_is_virtual() and id(tensor) not in tensors_found:
                raise ValueError(f"Node input {name} is a non-virtual tensor but not specified as input or output")
        # Set the input and output names
        self.__input_tuples = tuple(input_tensors)
        self.__output_tuples = tuple(output_tensors)