Branch data Line data Source code
1 : : // types.UnionType -- used to represent e.g. Union[int, str], int | str
2 : : #include "Python.h"
3 : : #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK
4 : : #include "pycore_unionobject.h"
5 : : #include "structmember.h"
6 : :
7 : :
8 : : static PyObject *make_union(PyObject *);
9 : :
10 : :
11 : : typedef struct {
12 : : PyObject_HEAD
13 : : PyObject *args;
14 : : PyObject *parameters;
15 : : } unionobject;
16 : :
17 : : static void
18 : 9317 : unionobject_dealloc(PyObject *self)
19 : : {
20 : 9317 : unionobject *alias = (unionobject *)self;
21 : :
22 : 9317 : _PyObject_GC_UNTRACK(self);
23 : :
24 : 9317 : Py_XDECREF(alias->args);
25 : 9317 : Py_XDECREF(alias->parameters);
26 : 9317 : Py_TYPE(self)->tp_free(self);
27 : 9317 : }
28 : :
29 : : static int
30 : 227204 : union_traverse(PyObject *self, visitproc visit, void *arg)
31 : : {
32 : 227204 : unionobject *alias = (unionobject *)self;
33 [ + - - + ]: 227204 : Py_VISIT(alias->args);
34 [ + + - + ]: 227204 : Py_VISIT(alias->parameters);
35 : 227204 : return 0;
36 : : }
37 : :
38 : : static Py_hash_t
39 : 19 : union_hash(PyObject *self)
40 : : {
41 : 19 : unionobject *alias = (unionobject *)self;
42 : 19 : PyObject *args = PyFrozenSet_New(alias->args);
43 [ - + ]: 19 : if (args == NULL) {
44 : 0 : return (Py_hash_t)-1;
45 : : }
46 : 19 : Py_hash_t hash = PyObject_Hash(args);
47 : 19 : Py_DECREF(args);
48 : 19 : return hash;
49 : : }
50 : :
51 : : static PyObject *
52 : 173 : union_richcompare(PyObject *a, PyObject *b, int op)
53 : : {
54 [ + + + + : 173 : if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) {
+ - ]
55 : 144 : Py_RETURN_NOTIMPLEMENTED;
56 : : }
57 : :
58 : 29 : PyObject *a_set = PySet_New(((unionobject*)a)->args);
59 [ - + ]: 29 : if (a_set == NULL) {
60 : 0 : return NULL;
61 : : }
62 : 29 : PyObject *b_set = PySet_New(((unionobject*)b)->args);
63 [ - + ]: 29 : if (b_set == NULL) {
64 : 0 : Py_DECREF(a_set);
65 : 0 : return NULL;
66 : : }
67 : 29 : PyObject *result = PyObject_RichCompare(a_set, b_set, op);
68 : 29 : Py_DECREF(b_set);
69 : 29 : Py_DECREF(a_set);
70 : 29 : return result;
71 : : }
72 : :
73 : : static int
74 : 9892 : is_same(PyObject *left, PyObject *right)
75 : : {
76 [ + + + + ]: 9892 : int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
77 [ + + ]: 9892 : return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
78 : : }
79 : :
80 : : static int
81 : 9338 : contains(PyObject **items, Py_ssize_t size, PyObject *obj)
82 : : {
83 [ + + ]: 19214 : for (int i = 0; i < size; i++) {
84 : 9892 : int is_duplicate = is_same(items[i], obj);
85 [ + + ]: 9892 : if (is_duplicate) { // -1 or 1
86 : 16 : return is_duplicate;
87 : : }
88 : : }
89 : 9322 : return 0;
90 : : }
91 : :
92 : : static PyObject *
93 : 9329 : merge(PyObject **items1, Py_ssize_t size1,
94 : : PyObject **items2, Py_ssize_t size2)
95 : : {
96 : 9329 : PyObject *tuple = NULL;
97 : 9329 : Py_ssize_t pos = 0;
98 : :
99 [ + + ]: 18666 : for (int i = 0; i < size2; i++) {
100 : 9338 : PyObject *arg = items2[i];
101 : 9338 : int is_duplicate = contains(items1, size1, arg);
102 [ + + ]: 9338 : if (is_duplicate < 0) {
103 : 1 : Py_XDECREF(tuple);
104 : 1 : return NULL;
105 : : }
106 [ + + ]: 9337 : if (is_duplicate) {
107 : 15 : continue;
108 : : }
109 : :
110 [ + + ]: 9322 : if (tuple == NULL) {
111 : 9317 : tuple = PyTuple_New(size1 + size2 - i);
112 [ - + ]: 9317 : if (tuple == NULL) {
113 : 0 : return NULL;
114 : : }
115 [ + + ]: 19177 : for (; pos < size1; pos++) {
116 : 9860 : PyObject *a = items1[pos];
117 : 9860 : Py_INCREF(a);
118 : 9860 : PyTuple_SET_ITEM(tuple, pos, a);
119 : : }
120 : : }
121 : 9322 : Py_INCREF(arg);
122 : 9322 : PyTuple_SET_ITEM(tuple, pos, arg);
123 : 9322 : pos++;
124 : : }
125 : :
126 [ + + ]: 9328 : if (tuple) {
127 : 9317 : (void) _PyTuple_Resize(&tuple, pos);
128 : : }
129 : 9328 : return tuple;
130 : : }
131 : :
132 : : static PyObject **
133 : 18658 : get_types(PyObject **obj, Py_ssize_t *size)
134 : : {
135 [ + + ]: 18658 : if (*obj == Py_None) {
136 : 4321 : *obj = (PyObject *)&_PyNone_Type;
137 : : }
138 [ + + ]: 18658 : if (_PyUnion_Check(*obj)) {
139 : 533 : PyObject *args = ((unionobject *) *obj)->args;
140 : 533 : *size = PyTuple_GET_SIZE(args);
141 : 533 : return &PyTuple_GET_ITEM(args, 0);
142 : : }
143 : : else {
144 : 18125 : *size = 1;
145 : 18125 : return obj;
146 : : }
147 : : }
148 : :
149 : : static int
150 : 20506 : is_unionable(PyObject *obj)
151 : : {
152 [ + + ]: 16185 : return (obj == Py_None ||
153 [ + + ]: 23899 : PyType_Check(obj) ||
154 [ + + + + ]: 44405 : _PyGenericAlias_Check(obj) ||
155 : 1461 : _PyUnion_Check(obj));
156 : : }
157 : :
158 : : PyObject *
159 : 10254 : _Py_union_type_or(PyObject* self, PyObject* other)
160 : : {
161 [ + + + + ]: 10254 : if (!is_unionable(self) || !is_unionable(other)) {
162 : 925 : Py_RETURN_NOTIMPLEMENTED;
163 : : }
164 : :
165 : : Py_ssize_t size1, size2;
166 : 9329 : PyObject **items1 = get_types(&self, &size1);
167 : 9329 : PyObject **items2 = get_types(&other, &size2);
168 : 9329 : PyObject *tuple = merge(items1, size1, items2, size2);
169 [ + + ]: 9329 : if (tuple == NULL) {
170 [ + + ]: 12 : if (PyErr_Occurred()) {
171 : 1 : return NULL;
172 : : }
173 : 11 : Py_INCREF(self);
174 : 11 : return self;
175 : : }
176 : :
177 : 9317 : PyObject *new_union = make_union(tuple);
178 : 9317 : Py_DECREF(tuple);
179 : 9317 : return new_union;
180 : : }
181 : :
182 : : static int
183 : 52 : union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
184 : : {
185 : 52 : PyObject *qualname = NULL;
186 : 52 : PyObject *module = NULL;
187 : : PyObject *tmp;
188 : 52 : PyObject *r = NULL;
189 : : int err;
190 : :
191 [ + + ]: 52 : if (p == (PyObject *)&_PyNone_Type) {
192 : 2 : return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
193 : : }
194 : :
195 [ - + ]: 50 : if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) {
196 : 0 : goto exit;
197 : : }
198 : :
199 [ + + ]: 50 : if (tmp) {
200 : 11 : Py_DECREF(tmp);
201 [ - + ]: 11 : if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) {
202 : 0 : goto exit;
203 : : }
204 [ + - ]: 11 : if (tmp) {
205 : : // It looks like a GenericAlias
206 : 11 : Py_DECREF(tmp);
207 : 11 : goto use_repr;
208 : : }
209 : : }
210 : :
211 [ - + ]: 39 : if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
212 : 0 : goto exit;
213 : : }
214 [ - + ]: 39 : if (qualname == NULL) {
215 : 0 : goto use_repr;
216 : : }
217 [ - + ]: 39 : if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) {
218 : 0 : goto exit;
219 : : }
220 [ + - - + ]: 39 : if (module == NULL || module == Py_None) {
221 : 0 : goto use_repr;
222 : : }
223 : :
224 : : // Looks like a class
225 [ + - + - ]: 78 : if (PyUnicode_Check(module) &&
226 : 39 : _PyUnicode_EqualToASCIIString(module, "builtins"))
227 : : {
228 : : // builtins don't need a module name
229 : 39 : r = PyObject_Str(qualname);
230 : 39 : goto exit;
231 : : }
232 : : else {
233 : 0 : r = PyUnicode_FromFormat("%S.%S", module, qualname);
234 : 0 : goto exit;
235 : : }
236 : :
237 : 11 : use_repr:
238 : 11 : r = PyObject_Repr(p);
239 : 50 : exit:
240 : 50 : Py_XDECREF(qualname);
241 : 50 : Py_XDECREF(module);
242 [ - + ]: 50 : if (r == NULL) {
243 : 0 : return -1;
244 : : }
245 : 50 : err = _PyUnicodeWriter_WriteStr(writer, r);
246 : 50 : Py_DECREF(r);
247 : 50 : return err;
248 : : }
249 : :
250 : : static PyObject *
251 : 24 : union_repr(PyObject *self)
252 : : {
253 : 24 : unionobject *alias = (unionobject *)self;
254 : 24 : Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
255 : :
256 : : _PyUnicodeWriter writer;
257 : 24 : _PyUnicodeWriter_Init(&writer);
258 [ + + ]: 76 : for (Py_ssize_t i = 0; i < len; i++) {
259 [ + + - + ]: 52 : if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
260 : 0 : goto error;
261 : : }
262 : 52 : PyObject *p = PyTuple_GET_ITEM(alias->args, i);
263 [ - + ]: 52 : if (union_repr_item(&writer, p) < 0) {
264 : 0 : goto error;
265 : : }
266 : : }
267 : 24 : return _PyUnicodeWriter_Finish(&writer);
268 : 0 : error:
269 : 0 : _PyUnicodeWriter_Dealloc(&writer);
270 : 0 : return NULL;
271 : : }
272 : :
273 : : static PyMemberDef union_members[] = {
274 : : {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
275 : : {0}
276 : : };
277 : :
278 : : static PyObject *
279 : 5 : union_getitem(PyObject *self, PyObject *item)
280 : : {
281 : 5 : unionobject *alias = (unionobject *)self;
282 : : // Populate __parameters__ if needed.
283 [ + + ]: 5 : if (alias->parameters == NULL) {
284 : 3 : alias->parameters = _Py_make_parameters(alias->args);
285 [ - + ]: 3 : if (alias->parameters == NULL) {
286 : 0 : return NULL;
287 : : }
288 : : }
289 : :
290 : 5 : PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
291 [ - + ]: 5 : if (newargs == NULL) {
292 : 0 : return NULL;
293 : : }
294 : :
295 : : PyObject *res;
296 : 5 : Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
297 [ - + ]: 5 : if (nargs == 0) {
298 : 0 : res = make_union(newargs);
299 : : }
300 : : else {
301 : 5 : res = PyTuple_GET_ITEM(newargs, 0);
302 : 5 : Py_INCREF(res);
303 [ + + ]: 10 : for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
304 : 5 : PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
305 : 5 : Py_SETREF(res, PyNumber_Or(res, arg));
306 [ - + ]: 5 : if (res == NULL) {
307 : 0 : break;
308 : : }
309 : : }
310 : : }
311 : 5 : Py_DECREF(newargs);
312 : 5 : return res;
313 : : }
314 : :
315 : : static PyMappingMethods union_as_mapping = {
316 : : .mp_subscript = union_getitem,
317 : : };
318 : :
319 : : static PyObject *
320 : 26 : union_parameters(PyObject *self, void *Py_UNUSED(unused))
321 : : {
322 : 26 : unionobject *alias = (unionobject *)self;
323 [ + + ]: 26 : if (alias->parameters == NULL) {
324 : 16 : alias->parameters = _Py_make_parameters(alias->args);
325 [ - + ]: 16 : if (alias->parameters == NULL) {
326 : 0 : return NULL;
327 : : }
328 : : }
329 : 26 : Py_INCREF(alias->parameters);
330 : 26 : return alias->parameters;
331 : : }
332 : :
333 : : static PyGetSetDef union_properties[] = {
334 : : {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
335 : : {0}
336 : : };
337 : :
338 : : static PyNumberMethods union_as_number = {
339 : : .nb_or = _Py_union_type_or, // Add __or__ function
340 : : };
341 : :
342 : : static const char* const cls_attrs[] = {
343 : : "__module__", // Required for compatibility with typing module
344 : : NULL,
345 : : };
346 : :
347 : : static PyObject *
348 : 750 : union_getattro(PyObject *self, PyObject *name)
349 : : {
350 : 750 : unionobject *alias = (unionobject *)self;
351 [ + - ]: 750 : if (PyUnicode_Check(name)) {
352 : 1497 : for (const char * const *p = cls_attrs; ; p++) {
353 [ + + ]: 1497 : if (*p == NULL) {
354 : 747 : break;
355 : : }
356 [ + + ]: 750 : if (_PyUnicode_EqualToASCIIString(name, *p)) {
357 : 3 : return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name);
358 : : }
359 : : }
360 : : }
361 : 747 : return PyObject_GenericGetAttr(self, name);
362 : : }
363 : :
364 : : PyObject *
365 : 38 : _Py_union_args(PyObject *self)
366 : : {
367 : : assert(_PyUnion_Check(self));
368 : 38 : return ((unionobject *) self)->args;
369 : : }
370 : :
371 : : PyTypeObject _PyUnion_Type = {
372 : : PyVarObject_HEAD_INIT(&PyType_Type, 0)
373 : : .tp_name = "types.UnionType",
374 : : .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
375 : : "\n"
376 : : "E.g. for int | str"),
377 : : .tp_basicsize = sizeof(unionobject),
378 : : .tp_dealloc = unionobject_dealloc,
379 : : .tp_alloc = PyType_GenericAlloc,
380 : : .tp_free = PyObject_GC_Del,
381 : : .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
382 : : .tp_traverse = union_traverse,
383 : : .tp_hash = union_hash,
384 : : .tp_getattro = union_getattro,
385 : : .tp_members = union_members,
386 : : .tp_richcompare = union_richcompare,
387 : : .tp_as_mapping = &union_as_mapping,
388 : : .tp_as_number = &union_as_number,
389 : : .tp_repr = union_repr,
390 : : .tp_getset = union_properties,
391 : : };
392 : :
393 : : static PyObject *
394 : 9317 : make_union(PyObject *args)
395 : : {
396 : : assert(PyTuple_CheckExact(args));
397 : :
398 : 9317 : unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
399 [ - + ]: 9317 : if (result == NULL) {
400 : 0 : return NULL;
401 : : }
402 : :
403 : 9317 : Py_INCREF(args);
404 : 9317 : result->parameters = NULL;
405 : 9317 : result->args = args;
406 : 9317 : _PyObject_GC_TRACK(result);
407 : 9317 : return (PyObject*)result;
408 : : }
|