diff --git a/.github/dependabot.yml b/.github/dependabot.yml index e9157d2f2..b93ab648d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -3,14 +3,14 @@ updates: - package-ecosystem: "cargo" directory: "/" schedule: - interval: "weekly" + interval: "monthly" - package-ecosystem: "pip" directory: "/" schedule: - interval: "weekly" + interval: "monthly" - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "weekly" + interval: "monthly" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 85174fa41..5a2b01049 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,7 +70,7 @@ jobs: - '3.9' - '3.10' - '3.11' - - '3.12-dev' + - '3.12' - 'pypy3.7' - 'pypy3.8' - 'pypy3.9' @@ -214,7 +214,9 @@ jobs: - run: pdm info && pdm list working-directory: pydantic - - run: pdm run pytest + # Run pytest with lax xfail because we often add tests to pydantic + # which xfail on a pending release of pydantic-core + - run: pdm run pytest --override-ini=xfail_strict=False working-directory: pydantic lint: @@ -236,7 +238,7 @@ jobs: python-version: '3.11' # used to lint js code - - uses: actions/setup-node@v3 + - uses: actions/setup-node@v4 with: node-version: '18' @@ -313,12 +315,12 @@ jobs: version: '3.1.32' actions-cache-folder: emsdk-cache - - run: pip install 'maturin>=1,<2' 'black>=22.3.0,<23' typing_extensions + - run: pip install 'maturin>=1,<2' 'ruff==0.1.3' typing_extensions - name: build wheels run: make build-wasm - - uses: actions/setup-node@v3 + - uses: actions/setup-node@v4 with: node-version: '18' @@ -389,7 +391,7 @@ jobs: interpreter: 3.11 3.12 - os: macos target: aarch64 - interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 pypy3.8 pypy3.9 pypy3.10 + interpreter: 3.7 3.8 3.9 pypy3.8 pypy3.9 pypy3.10 - os: ubuntu platform: linux target: i686 @@ -440,7 +442,7 @@ jobs: python-version: '3.11' architecture: ${{ matrix.python-architecture || 'x64' }} - - run: pip install -U twine 'black>=22.3.0,<23' typing_extensions + - run: pip install -U twine 'ruff==0.1.3' typing_extensions # generate self-schema now, so we don't have to do so inside docker in maturin build - run: python generate_self_schema.py @@ -465,25 +467,26 @@ jobs: path: dist build-pgo: - name: build pgo-optimized on ${{ matrix.platform || matrix.os }} (${{ matrix.interpreter}} - ${{ matrix.target }} - ${{ matrix.manylinux || 'auto' }}) + name: build pgo-optimized on ${{ matrix.os }} / ${{ matrix.interpreter }} # only run on push to main and on release if: startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || contains(github.event.pull_request.labels.*.name, 'Full Build') strategy: fail-fast: false matrix: - os: [ubuntu, windows] - target: [x86_64] - manylinux: [auto] - interpreter: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12-dev", "pypy3.7", "pypy3.8", "pypy3.9", "pypy3.10"] + os: [ubuntu-latest, windows-latest, macos-latest-xlarge] + interpreter: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] include: - - os: ubuntu - platform: linux - - os: windows + - os: windows-latest ls: dir - - interpreter: 3.12-dev - maturin-interpreter: "3.12" - - runs-on: ${{ matrix.os }}-latest + exclude: + - os: macos-latest-xlarge + interpreter: '3.7' + - os: macos-latest-xlarge + interpreter: '3.8' + - os: macos-latest-xlarge + interpreter: '3.9' + + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 @@ -491,7 +494,6 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.interpreter }} - architecture: ${{ matrix.python-architecture || 'x64' }} - name: install rust stable id: rust-toolchain @@ -499,20 +501,21 @@ jobs: with: components: llvm-tools - - run: pip install -U 'black>=22.3.0,<23' typing_extensions + - run: pip install -U 'ruff==0.1.3' typing_extensions # generate self-schema now, so we don't have to do so inside docker in maturin build - run: python generate_self_schema.py + - run: rustc --version --verbose + - name: build initial wheel uses: PyO3/maturin-action@v1 with: - target: ${{ matrix.target }} - manylinux: ${{ matrix.manylinux || 'auto' }} + manylinux: auto args: > --release --out pgo-wheel - --interpreter ${{ matrix.maturin-interpreter || matrix.interpreter }} + --interpreter ${{ matrix.interpreter }} rust-toolchain: stable docker-options: -e CI env: @@ -536,12 +539,11 @@ jobs: - name: build pgo-optimized wheel uses: PyO3/maturin-action@v1 with: - target: ${{ matrix.target }} - manylinux: ${{ matrix.manylinux || 'auto' }} + manylinux: auto args: > --release --out dist - --interpreter ${{ matrix.maturin-interpreter || matrix.interpreter }} + --interpreter ${{ matrix.interpreter }} rust-toolchain: stable docker-options: -e CI env: @@ -551,7 +553,7 @@ jobs: - uses: actions/upload-artifact@v3 with: - name: pypi_files + name: pypi_files_pgo path: dist inspect-pypi-assets: @@ -567,7 +569,19 @@ jobs: name: pypi_files path: dist - - name: list dist files + - name: list dist files before PGO builds + run: | + ls -lh dist/ + ls -l dist/ + echo "`ls dist | wc -l` files" + + - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) + uses: actions/download-artifact@v3 + with: + name: pypi_files_pgo + path: dist + + - name: list dist files with PGO builds run: | ls -lh dist/ ls -l dist/ @@ -607,6 +621,12 @@ jobs: name: pypi_files path: dist + - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) + uses: actions/download-artifact@v3 + with: + name: pypi_files_pgo + path: dist + - uses: uraimo/run-on-arch-action@v2.5.1 name: install & test with: @@ -659,6 +679,12 @@ jobs: name: pypi_files path: dist + - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) + uses: actions/download-artifact@v3 + with: + name: pypi_files_pgo + path: dist + - run: pip install typing-extensions - run: pip install -r tests/requirements.txt - run: pip install pydantic-core --no-index --no-deps --find-links dist --force-reinstall @@ -688,6 +714,12 @@ jobs: name: pypi_files path: dist + - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) + uses: actions/download-artifact@v3 + with: + name: pypi_files_pgo + path: dist + - run: twine check --strict dist/* - name: upload to pypi diff --git a/Cargo.lock b/Cargo.lock index a2db9901f..e00130f7d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,14 +4,15 @@ version = 3 [[package]] name = "ahash" -version = "0.8.3" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" dependencies = [ "cfg-if", "getrandom", "once_cell", "version_check", + "zerocopy", ] [[package]] @@ -31,9 +32,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.4" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "bitflags" @@ -62,7 +63,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.28", + "syn", ] [[package]] @@ -125,9 +126,9 @@ dependencies = [ [[package]] name = "indoc" -version = "1.0.9" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" [[package]] name = "itoa" @@ -135,6 +136,84 @@ version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" +[[package]] +name = "jiter" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b27d419c535bf7b50ad355278b1159cbf0cc8d507ea003d625b17bf0375720b8" +dependencies = [ + "ahash", + "lexical-core", + "num-bigint", + "num-traits", + "pyo3", + "smallvec", +] + +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-parse-integer" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + "lexical-util", + "static_assertions", +] + [[package]] name = "libc" version = "0.2.147" @@ -233,21 +312,22 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "proc-macro2" -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] [[package]] name = "pydantic-core" -version = "2.10.1" +version = "2.14.1" dependencies = [ "ahash", "base64", "enum_dispatch", "idna", + "jiter", "num-bigint", "pyo3", "pyo3-build-config", @@ -266,9 +346,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" dependencies = [ "cfg-if", "indoc", @@ -284,9 +364,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" dependencies = [ "once_cell", "python3-dll-a", @@ -295,9 +375,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" dependencies = [ "libc", "pyo3-build-config", @@ -305,25 +385,26 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 1.0.109", + "syn", ] [[package]] name = "pyo3-macros-backend" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" dependencies = [ + "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn", ] [[package]] @@ -355,9 +436,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.5" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", @@ -367,9 +448,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.8" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", @@ -378,9 +459,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.5" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "rustversion" @@ -402,29 +483,29 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.188" +version = "1.0.190" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.190" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn", ] [[package]] name = "serde_json" -version = "1.0.107" +version = "1.0.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" dependencies = [ "indexmap", "itoa", @@ -440,14 +521,20 @@ checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" [[package]] name = "speedate" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c028e117e67c1f3224f5f834b3e48d4133dc11ec509aa19fdfa6c0987efed332" +checksum = "242f76c50fd18cbf098607090ade73a08d39cfd84ea835f3796a2c855223b19b" dependencies = [ "strum", "strum_macros", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strum" version = "0.25.0" @@ -459,33 +546,22 @@ dependencies = [ [[package]] name = "strum_macros" -version = "0.25.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.28", + "syn", ] [[package]] name = "syn" -version = "1.0.109" +version = "2.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "syn" -version = "2.0.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" dependencies = [ "proc-macro2", "quote", @@ -536,9 +612,9 @@ dependencies = [ [[package]] name = "unindent" -version = "0.1.11" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" [[package]] name = "url" @@ -553,9 +629,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +checksum = "88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc" [[package]] name = "version_check" @@ -625,3 +701,23 @@ name = "windows_x86_64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + +[[package]] +name = "zerocopy" +version = "0.7.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd66a62464e3ffd4e37bd09950c2b9dd6c4f8767380fabba0d523f9a775bc85a" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "255c4596d41e6916ced49cfafea18727b24d67878fa180ddfd69b9df34fd1726" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 7101a7019..f7c859534 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.10.1" +version = "2.14.1" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" @@ -26,23 +26,25 @@ include = [ ] [dependencies] -pyo3 = { version = "0.19.2", features = ["generate-import-lib", "num-bigint"] } -regex = "1.9.5" +pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] } +regex = "1.10.2" strum = { version = "0.25.0", features = ["derive"] } -strum_macros = "0.25.2" -serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} +strum_macros = "0.25.3" +serde_json = {version = "1.0.108", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" -serde = { version = "1.0.188", features = ["derive"] } -speedate = "0.12.0" +serde = { version = "1.0.190", features = ["derive"] } +speedate = "0.13.0" smallvec = "1.11.1" -ahash = "0.8.0" +ahash = "0.8.6" url = "2.4.1" # idna is already required by url, added here to be explicit idna = "0.4.0" -base64 = "0.21.4" +base64 = "0.21.5" num-bigint = "0.4.4" python3-dll-a = "0.2.7" -uuid = "1.4.1" +uuid = "1.5.0" +jiter = {version = "0.0.4", features = ["python"]} +#jiter = {path = "../jiter", features = ["python"]} [lib] name = "_pydantic_core" @@ -62,9 +64,9 @@ debug = true strip = false [dev-dependencies] -pyo3 = { version= "0.19.2", features = ["auto-initialize"] } +pyo3 = { version = "0.20.0", features = ["auto-initialize"] } [build-dependencies] version_check = "0.9.4" # used where logic has to be version/distribution specific, e.g. pypy -pyo3-build-config = { version = "0.19.2" } +pyo3-build-config = { version = "0.20.0" } diff --git a/Makefile b/Makefile index 4e97e6f42..d9e0d0e0a 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .DEFAULT_GOAL := all -black = black python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py -ruff = ruff python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py +sources = python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py + mypy-stubtest = python -m mypy.stubtest pydantic_core._pydantic_core --allowlist .mypy-stubtest-allowlist # using pip install cargo (via maturin via pip) doesn't get the tty handle @@ -90,14 +90,14 @@ build-wasm: .PHONY: format format: - $(black) - $(ruff) --fix --exit-zero + ruff --fix $(sources) + ruff format $(sources) cargo fmt .PHONY: lint-python lint-python: - $(ruff) - $(black) --check --diff + ruff $(sources) + ruff format --check $(sources) $(mypy-stubtest) griffe dump -f -d google -LWARNING -o/dev/null python/pydantic_core diff --git a/benches/main.rs b/benches/main.rs index 9d46131d1..4b8a2b106 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -458,6 +458,50 @@ fn complete_model(bench: &mut Bencher) { }) } +#[bench] +fn nested_model_using_definitions(bench: &mut Bencher) { + Python::with_gil(|py| { + let sys_path = py.import("sys").unwrap().getattr("path").unwrap(); + sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap(); + + let complete_schema = py.import("nested_schema").unwrap(); + let mut schema = complete_schema.call_method0("schema_using_defs").unwrap(); + schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); + let validator = SchemaValidator::py_new(py, schema, None).unwrap(); + + let input = complete_schema.call_method0("input_data_valid").unwrap(); + let input = black_box(input); + + validator.validate_python(py, input, None, None, None, None).unwrap(); + + bench.iter(|| { + black_box(validator.validate_python(py, input, None, None, None, None).unwrap()); + }) + }) +} + +#[bench] +fn nested_model_inlined(bench: &mut Bencher) { + Python::with_gil(|py| { + let sys_path = py.import("sys").unwrap().getattr("path").unwrap(); + sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap(); + + let complete_schema = py.import("nested_schema").unwrap(); + let mut schema = complete_schema.call_method0("inlined_schema").unwrap(); + schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); + let validator = SchemaValidator::py_new(py, schema, None).unwrap(); + + let input = complete_schema.call_method0("input_data_valid").unwrap(); + let input = black_box(input); + + validator.validate_python(py, input, None, None, None, None).unwrap(); + + bench.iter(|| { + black_box(validator.validate_python(py, input, None, None, None, None).unwrap()); + }) + }) +} + #[bench] fn literal_ints_few_python(bench: &mut Bencher) { Python::with_gil(|py| { diff --git a/pyproject.toml b/pyproject.toml index 7afacd57a..a40f160f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,13 +57,15 @@ features = ["pyo3/extension-module"] line-length = 120 extend-select = ['Q', 'RUF100', 'C90', 'I'] extend-ignore = [ - 'E501', # ignore line too long and let black handle it 'E721', # using type() instead of isinstance() - we use this in tests ] flake8-quotes = {inline-quotes = 'single', multiline-quotes = 'double'} mccabe = { max-complexity = 13 } isort = { known-first-party = ['pydantic_core', 'tests'] } +[tool.ruff.format] +quote-style = 'single' + [tool.pytest.ini_options] testpaths = 'tests' log_format = '%(name)s %(levelname)s: %(message)s' @@ -97,13 +99,6 @@ exclude_lines = [ '@overload', ] -[tool.black] -color = true -line-length = 120 -target-version = ['py37', 'py38', 'py39', 'py310'] -skip-string-normalization = true -skip-magic-trailing-comma = true - # configuring https://github.com/pydantic/hooky [tool.hooky] assignees = ['samuelcolvin', 'adriangb', 'dmontagu', 'davidhewitt', 'lig'] diff --git a/python/pydantic_core/__init__.py b/python/pydantic_core/__init__.py index a46a77b7d..5b2655c91 100644 --- a/python/pydantic_core/__init__.py +++ b/python/pydantic_core/__init__.py @@ -22,6 +22,7 @@ Url, ValidationError, __version__, + from_json, to_json, to_jsonable_python, validate_core_schema, @@ -63,6 +64,7 @@ 'PydanticSerializationUnexpectedValue', 'TzInfo', 'to_json', + 'from_json', 'to_jsonable_python', 'validate_core_schema', ] diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 8ed3092a9..b452d2f17 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -41,6 +41,7 @@ __all__ = [ 'PydanticUndefinedType', 'Some', 'to_json', + 'from_json', 'to_jsonable_python', 'list_all_errors', 'TzInfo', @@ -384,6 +385,23 @@ def to_json( JSON bytes. """ +def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> Any: + """ + Deserialize JSON data to a Python object. + + This is effectively a faster version of [`json.loads()`][json.loads]. + + Arguments: + data: The JSON data to deserialize. + allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default. + + Raises: + ValueError: If deserialization fails. + + Returns: + The deserialized Python object. + """ + def to_jsonable_python( value: Any, *, @@ -829,7 +847,6 @@ def list_all_errors() -> list[ErrorTypeInfo]: Returns: A list of `ErrorTypeInfo` typed dicts. """ - @final class TzInfo(datetime.tzinfo): def tzname(self, _dt: datetime.datetime | None) -> str | None: ... diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 2d7061ffd..fec3b9966 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -68,6 +68,8 @@ class CoreConfig(TypedDict, total=False): allow_inf_nan: Whether to allow infinity and NaN values for float fields. Default is `True`. ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'. ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'. + ser_json_inf_nan: The serialization option for infinity and NaN values + in float fields. Default is 'null'. hide_input_in_errors: Whether to hide input data from `ValidationError` representation. validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError. Requires exceptiongroup backport pre Python 3.11. @@ -101,11 +103,13 @@ class CoreConfig(TypedDict, total=False): allow_inf_nan: bool # default: True # the config options are used to customise serialization to JSON ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601' - ser_json_bytes: Literal['utf8', 'base64'] # default: 'utf8' + ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8' + ser_json_inf_nan: Literal['null', 'constants'] # default: 'null' # used to hide input data from ValidationError repr hide_input_in_errors: bool validation_error_cause: bool # default: False coerce_numbers_to_str: bool # default: False + regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex' IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' @@ -3909,6 +3913,9 @@ def general_plain_validator_function(*args, **kwargs): 'FieldWrapValidatorFunction': WithInfoWrapValidatorFunction, } +if TYPE_CHECKING: + FieldValidationInfo = ValidationInfo + def __getattr__(attr_name: str) -> object: new_attr = _deprecated_import_lookup.get(attr_name) diff --git a/src/definitions.rs b/src/definitions.rs index 0d01fd2ae..4627fd2d1 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -3,16 +3,20 @@ /// Unlike json schema we let you put definitions inline, not just in a single '#/$defs/' block or similar. /// We use DefinitionsBuilder to collect the references / definitions into a single vector /// and then get a definition from a reference using an integer id (just for performance of not using a HashMap) -use std::collections::hash_map::Entry; +use std::{ + collections::hash_map::Entry, + fmt::Debug, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, OnceLock, + }, +}; -use pyo3::prelude::*; +use pyo3::{prelude::*, PyTraverseError, PyVisit}; use ahash::AHashMap; -use crate::build_tools::py_schema_err; - -// An integer id for the reference -pub type ReferenceId = usize; +use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse}; /// Definitions are validators and serializers that are /// shared by reference. @@ -24,91 +28,215 @@ pub type ReferenceId = usize; /// They get indexed by a ReferenceId, which are integer identifiers /// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer} /// gets build. -pub type Definitions = [T]; +#[derive(Clone)] +pub struct Definitions(AHashMap, Definition>); -#[derive(Clone, Debug)] -struct Definition { - pub id: ReferenceId, - pub value: Option, +/// Internal type which contains a definition to be filled +pub struct Definition(Arc>); + +struct DefinitionInner { + value: OnceLock, + name: LazyName, +} + +/// Reference to a definition. +pub struct DefinitionRef { + name: Arc, + value: Definition, +} + +// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone) +impl Clone for DefinitionRef { + fn clone(&self) -> Self { + Self { + name: self.name.clone(), + value: self.value.clone(), + } + } +} + +impl DefinitionRef { + pub fn id(&self) -> usize { + Arc::as_ptr(&self.value.0) as usize + } + + pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str { + match self.value.0.value.get() { + Some(value) => self.value.0.name.get_or_init(|| init(value)), + None => "...", + } + } + + pub fn get(&self) -> Option<&T> { + self.value.0.value.get() + } +} + +impl Debug for DefinitionRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To avoid possible infinite recursion from recursive definitions, + // a DefinitionRef just displays debug as its name + self.name.fmt(f) + } +} + +impl Debug for Definitions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Formatted as a list for backwards compatibility; in principle + // this could be formatted as a map. Maybe change in a future + // minor release of pydantic. + write![f, "["]?; + let mut first = true; + for def in self.0.values() { + write![f, "{sep}{def:?}", sep = if first { "" } else { ", " }]?; + first = false; + } + write![f, "]"]?; + Ok(()) + } +} + +impl Clone for Definition { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Debug for Definition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0.value.get() { + Some(value) => value.fmt(f), + None => "...".fmt(f), + } + } +} + +impl PyGcTraverse for DefinitionRef { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + if let Some(value) = self.value.0.value.get() { + value.py_gc_traverse(visit)?; + } + Ok(()) + } +} + +impl PyGcTraverse for Definitions { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + for value in self.0.values() { + if let Some(value) = value.0.value.get() { + value.py_gc_traverse(visit)?; + } + } + Ok(()) + } } #[derive(Clone, Debug)] pub struct DefinitionsBuilder { - definitions: AHashMap>, + definitions: Definitions, } -impl DefinitionsBuilder { +impl DefinitionsBuilder { pub fn new() -> Self { Self { - definitions: AHashMap::new(), + definitions: Definitions(AHashMap::new()), } } /// Get a ReferenceId for the given reference string. - // This ReferenceId can later be used to retrieve a definition - pub fn get_reference_id(&mut self, reference: &str) -> ReferenceId { - let next_id = self.definitions.len(); + pub fn get_definition(&mut self, reference: &str) -> DefinitionRef { // We either need a String copy or two hashmap lookups // Neither is better than the other // We opted for the easier outward facing API - match self.definitions.entry(reference.to_string()) { - Entry::Occupied(entry) => entry.get().id, - Entry::Vacant(entry) => { - entry.insert(Definition { - id: next_id, - value: None, - }); - next_id - } + let name = Arc::new(reference.to_string()); + let value = match self.definitions.0.entry(name.clone()) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner { + value: OnceLock::new(), + name: LazyName::new(), + }))), + }; + DefinitionRef { + name, + value: value.clone(), } } /// Add a definition, returning the ReferenceId that maps to it - pub fn add_definition(&mut self, reference: String, value: T) -> PyResult { - let next_id = self.definitions.len(); - match self.definitions.entry(reference.clone()) { - Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) { - Some(_) => py_schema_err!("Duplicate ref: `{}`", reference), - None => Ok(entry.get().id), - }, - Entry::Vacant(entry) => { - entry.insert(Definition { - id: next_id, - value: Some(value), - }); - Ok(next_id) + pub fn add_definition(&mut self, reference: String, value: T) -> PyResult> { + let name = Arc::new(reference); + let value = match self.definitions.0.entry(name.clone()) { + Entry::Occupied(entry) => { + let definition = entry.into_mut(); + match definition.0.value.set(value) { + Ok(()) => definition.clone(), + Err(_) => return py_schema_err!("Duplicate ref: `{}`", name), + } + } + Entry::Vacant(entry) => entry + .insert(Definition(Arc::new(DefinitionInner { + value: OnceLock::from(value), + name: LazyName::new(), + }))) + .clone(), + }; + Ok(DefinitionRef { name, value }) + } + + /// Consume this Definitions into a vector of items, indexed by each items ReferenceId + pub fn finish(self) -> PyResult> { + for (reference, def) in &self.definitions.0 { + if def.0.value.get().is_none() { + return py_schema_err!("Definitions error: definition `{}` was never filled", reference); } } + Ok(self.definitions) + } +} + +struct LazyName { + initialized: OnceLock, + in_recursion: AtomicBool, +} + +impl LazyName { + fn new() -> Self { + Self { + initialized: OnceLock::new(), + in_recursion: AtomicBool::new(false), + } } - /// Retrieve an item definition using a ReferenceId - /// If the definition doesn't yet exist (as happens in recursive types) then we create it - /// At the end (in finish()) we check that there are no undefined definitions - pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> { - let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) { - Some(v) => v, - None => return py_schema_err!("Definitions error: no definition for ReferenceId `{}`", reference_id), - }; - match def.value.as_ref() { - Some(v) => Ok(v), - None => py_schema_err!( - "Definitions error: attempted to use `{}` before it was filled", - reference - ), + /// Gets the validator name, returning the default in the case of recursion loops + fn get_or_init(&self, init: impl FnOnce() -> String) -> &str { + if let Some(s) = self.initialized.get() { + return s.as_str(); + } + + if self + .in_recursion + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_err() + { + return "..."; } + let result = self.initialized.get_or_init(init).as_str(); + self.in_recursion.store(false, Ordering::SeqCst); + result } +} - /// Consume this Definitions into a vector of items, indexed by each items ReferenceId - pub fn finish(self) -> PyResult> { - // We need to create a vec of defs according to the order in their ids - let mut defs: Vec<(usize, T)> = Vec::new(); - for (reference, def) in self.definitions { - match def.value { - None => return py_schema_err!("Definitions error: definition {} was never filled", reference), - Some(v) => defs.push((def.id, v)), - } +impl Debug for LazyName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.initialized.get().map_or("...", String::as_str).fmt(f) + } +} + +impl Clone for LazyName { + fn clone(&self) -> Self { + Self { + initialized: OnceLock::new(), + in_recursion: AtomicBool::new(false), } - defs.sort_by_key(|(id, _)| *id); - Ok(defs.into_iter().map(|(_, v)| v).collect()) } } diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index e5d3c7bac..3ee4c7894 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -2,7 +2,9 @@ use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::PyDowncastError; -use crate::input::{Input, JsonInput}; +use jiter::JsonValue; + +use crate::input::Input; use super::location::{LocItem, Location}; use super::types::ErrorType; @@ -147,7 +149,7 @@ impl<'a> ValLineError<'a> { #[derive(Clone)] pub enum InputValue<'a> { PyAny(&'a PyAny), - JsonInput(JsonInput), + JsonInput(JsonValue), String(&'a str), } diff --git a/src/errors/location.rs b/src/errors/location.rs index e5c32d5e2..8acc2a039 100644 --- a/src/errors/location.rs +++ b/src/errors/location.rs @@ -3,12 +3,11 @@ use pyo3::once_cell::GILOnceCell; use std::fmt; use pyo3::prelude::*; -use pyo3::types::{PyList, PyString, PyTuple}; +use pyo3::types::{PyList, PyTuple}; use serde::ser::SerializeSeq; use serde::{Serialize, Serializer}; use crate::lookup_key::{LookupPath, PathItem}; -use crate::tools::extract_i64; /// Used to store individual items of the error location, e.g. a string for key/field names /// or a number for array indices. @@ -35,6 +34,12 @@ impl fmt::Display for LocItem { } } +// TODO rename to ToLocItem +pub trait AsLocItem { + // TODO rename to to_loc_item + fn as_loc_item(&self) -> LocItem; +} + impl From for LocItem { fn from(s: String) -> Self { Self::S(s) @@ -82,21 +87,6 @@ impl ToPyObject for LocItem { } } -impl TryFrom<&PyAny> for LocItem { - type Error = PyErr; - - fn try_from(loc_item: &PyAny) -> PyResult { - if let Ok(py_str) = loc_item.downcast::() { - let str = py_str.to_str()?.to_string(); - Ok(Self::S(str)) - } else if let Ok(int) = extract_i64(loc_item) { - Ok(Self::I(int)) - } else { - Err(PyTypeError::new_err("Item in a location must be a string or int")) - } - } -} - impl Serialize for LocItem { fn serialize(&self, serializer: S) -> Result where @@ -211,9 +201,9 @@ impl TryFrom> for Location { fn try_from(location: Option<&PyAny>) -> PyResult { if let Some(location) = location { let mut loc_vec: Vec = if let Ok(tuple) = location.downcast::() { - tuple.iter().map(LocItem::try_from).collect::>()? + tuple.iter().map(AsLocItem::as_loc_item).collect() } else if let Ok(list) = location.downcast::() { - list.iter().map(LocItem::try_from).collect::>()? + list.iter().map(AsLocItem::as_loc_item).collect() } else { return Err(PyTypeError::new_err( "Location must be a list or tuple of strings and ints", diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 6a253197f..bfc5b4329 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -7,7 +7,7 @@ mod validation_exception; mod value_exception; pub use self::line_error::{InputValue, ValError, ValLineError, ValResult}; -pub use self::location::LocItem; +pub use self::location::{AsLocItem, LocItem}; pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number}; pub use self::validation_exception::ValidationError; pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault}; diff --git a/src/errors/types.rs b/src/errors/types.rs index d537158ba..5c3fc1a7c 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -50,7 +50,7 @@ fn field_from_context<'py, T: FromPyObject<'py>>( ) -> PyResult { context .ok_or_else(|| py_error_type!(PyTypeError; "{}: '{}' required in context", enum_name, field_name))? - .get_item(field_name) + .get_item(field_name)? .ok_or_else(|| py_error_type!(PyTypeError; "{}: '{}' required in context", enum_name, field_name))? .extract::() .map_err(|_| py_error_type!(PyTypeError; "{}: '{}' context value must be a {}", enum_name, field_name, type_name_fn())) @@ -445,8 +445,8 @@ macro_rules! to_string_render { }; } -fn plural_s(value: usize) -> &'static str { - if value == 1 { +fn plural_s + PartialEq>(value: T) -> &'static str { + if value == 1.into() { "" } else { "s" @@ -494,8 +494,8 @@ impl ErrorType { Self::StringType {..} => "Input should be a valid string", Self::StringSubType {..} => "Input should be a string, not an instance of a subclass of str", Self::StringUnicode {..} => "Input should be a valid string, unable to parse raw data as a unicode string", - Self::StringTooShort {..} => "String should have at least {min_length} characters", - Self::StringTooLong {..} => "String should have at most {max_length} characters", + Self::StringTooShort {..} => "String should have at least {min_length} character{expected_plural}", + Self::StringTooLong {..} => "String should have at most {max_length} character{expected_plural}", Self::StringPatternMismatch {..} => "String should match pattern '{pattern}'", Self::Enum {..} => "Input should be {expected}", Self::DictType {..} => "Input should be a valid dictionary", @@ -512,8 +512,8 @@ impl ErrorType { Self::FloatType {..} => "Input should be a valid number", Self::FloatParsing {..} => "Input should be a valid number, unable to parse string as a number", Self::BytesType {..} => "Input should be a valid bytes", - Self::BytesTooShort {..} => "Data should have at least {min_length} bytes", - Self::BytesTooLong {..} => "Data should have at most {max_length} bytes", + Self::BytesTooShort {..} => "Data should have at least {min_length} byte{expected_plural}", + Self::BytesTooLong {..} => "Data should have at most {max_length} byte{expected_plural}", Self::ValueError {..} => "Value error, {error}", Self::AssertionError {..} => "Assertion failed, {error}", Self::CustomError {..} => "", // custom errors are handled separately @@ -552,16 +552,16 @@ impl ErrorType { Self::UrlType {..} => "URL input should be a string or URL", Self::UrlParsing {..} => "Input should be a valid URL, {error}", Self::UrlSyntaxViolation {..} => "Input violated strict URL syntax rules, {error}", - Self::UrlTooLong {..} => "URL should have at most {max_length} characters", + Self::UrlTooLong {..} => "URL should have at most {max_length} character{expected_plural}", Self::UrlScheme {..} => "URL scheme should be {expected_schemes}", Self::UuidType {..} => "UUID input should be a string, bytes or UUID object", Self::UuidParsing {..} => "Input should be a valid UUID, {error}", Self::UuidVersion {..} => "UUID version {expected_version} expected", Self::DecimalType {..} => "Decimal input should be an integer, float, string or Decimal object", Self::DecimalParsing {..} => "Input should be a valid decimal", - Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digits in total", - Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal places", - Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digits before the decimal point", + Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total", + Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", + Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", } } @@ -643,13 +643,25 @@ impl ErrorType { to_string_render!(tmpl, field_type, max_length, actual_length, expected_plural,) } Self::IterationError { error, .. } => render!(tmpl, error), - Self::StringTooShort { min_length, .. } => to_string_render!(tmpl, min_length), - Self::StringTooLong { max_length, .. } => to_string_render!(tmpl, max_length), + Self::StringTooShort { min_length, .. } => { + let expected_plural = plural_s(*min_length); + to_string_render!(tmpl, min_length, expected_plural) + } + Self::StringTooLong { max_length, .. } => { + let expected_plural = plural_s(*max_length); + to_string_render!(tmpl, max_length, expected_plural) + } Self::StringPatternMismatch { pattern, .. } => render!(tmpl, pattern), Self::Enum { expected, .. } => to_string_render!(tmpl, expected), Self::MappingType { error, .. } => render!(tmpl, error), - Self::BytesTooShort { min_length, .. } => to_string_render!(tmpl, min_length), - Self::BytesTooLong { max_length, .. } => to_string_render!(tmpl, max_length), + Self::BytesTooShort { min_length, .. } => { + let expected_plural = plural_s(*min_length); + to_string_render!(tmpl, min_length, expected_plural) + } + Self::BytesTooLong { max_length, .. } => { + let expected_plural = plural_s(*max_length); + to_string_render!(tmpl, max_length, expected_plural) + } Self::ValueError { error, .. } => { let error = &error .as_ref() @@ -688,13 +700,25 @@ impl ErrorType { Self::UnionTagNotFound { discriminator, .. } => render!(tmpl, discriminator), Self::UrlParsing { error, .. } => render!(tmpl, error), Self::UrlSyntaxViolation { error, .. } => render!(tmpl, error), - Self::UrlTooLong { max_length, .. } => to_string_render!(tmpl, max_length), + Self::UrlTooLong { max_length, .. } => { + let expected_plural = plural_s(*max_length); + to_string_render!(tmpl, max_length, expected_plural) + } Self::UrlScheme { expected_schemes, .. } => render!(tmpl, expected_schemes), Self::UuidParsing { error, .. } => render!(tmpl, error), Self::UuidVersion { expected_version, .. } => to_string_render!(tmpl, expected_version), - Self::DecimalMaxDigits { max_digits, .. } => to_string_render!(tmpl, max_digits), - Self::DecimalMaxPlaces { decimal_places, .. } => to_string_render!(tmpl, decimal_places), - Self::DecimalWholeDigits { whole_digits, .. } => to_string_render!(tmpl, whole_digits), + Self::DecimalMaxDigits { max_digits, .. } => { + let expected_plural = plural_s(*max_digits); + to_string_render!(tmpl, max_digits, expected_plural) + } + Self::DecimalMaxPlaces { decimal_places, .. } => { + let expected_plural = plural_s(*decimal_places); + to_string_render!(tmpl, decimal_places, expected_plural) + } + Self::DecimalWholeDigits { whole_digits, .. } => { + let expected_plural = plural_s(*whole_digits); + to_string_render!(tmpl, whole_digits, expected_plural) + } _ => Ok(tmpl.to_string()), } } diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 09154d8ac..d616e3022 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -4,10 +4,10 @@ use std::str::from_utf8; use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError}; use pyo3::ffi; +use pyo3::intern; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; -use pyo3::{intern, AsPyPointer}; use serde::ser::{Error, SerializeMap, SerializeSeq}; use serde::{Serialize, Serializer}; @@ -324,12 +324,12 @@ impl ValidationError { Some(indent) => { let indent = vec![b' '; indent]; let formatter = PrettyFormatter::with_indent(&indent); - let mut ser = serde_json::Serializer::with_formatter(writer, formatter); + let mut ser = crate::serializers::ser::PythonSerializer::with_formatter(writer, formatter); serializer.serialize(&mut ser).map_err(json_py_err)?; ser.into_inner() } None => { - let mut ser = serde_json::Serializer::new(writer); + let mut ser = crate::serializers::ser::PythonSerializer::new(writer); serializer.serialize(&mut ser).map_err(json_py_err)?; ser.into_inner() } @@ -445,7 +445,7 @@ impl TryFrom<&PyAny> for PyLineError { let py = value.py(); let type_raw = dict - .get_item(intern!(py, "type")) + .get_item(intern!(py, "type"))? .ok_or_else(|| PyKeyError::new_err("type"))?; let error_type = if let Ok(type_str) = type_raw.downcast::() { @@ -459,9 +459,9 @@ impl TryFrom<&PyAny> for PyLineError { )); }; - let location = Location::try_from(dict.get_item("loc"))?; + let location = Location::try_from(dict.get_item("loc")?)?; - let input_value = match dict.get_item("input") { + let input_value = match dict.get_item("input")? { Some(i) => i.into_py(py), None => py.None(), }; diff --git a/src/errors/value_exception.rs b/src/errors/value_exception.rs index f7d877b30..7bc7e5227 100644 --- a/src/errors/value_exception.rs +++ b/src/errors/value_exception.rs @@ -72,7 +72,7 @@ impl PydanticCustomError { } } - #[getter(type)] + #[getter(r#type)] pub fn error_type(&self) -> String { self.error_type.clone() } @@ -147,7 +147,7 @@ impl PydanticKnownError { Ok(Self { error_type }) } - #[getter(type)] + #[getter(r#type)] pub fn error_type(&self) -> String { self.error_type.to_string() } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 655ba24b9..ba6fbd0a1 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -4,13 +4,15 @@ use pyo3::exceptions::PyValueError; use pyo3::types::{PyDict, PyType}; use pyo3::{intern, prelude::*}; -use crate::errors::{InputValue, LocItem, ValResult}; +use jiter::JsonValue; + +use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult}; use crate::tools::py_err; use crate::{PyMultiHostUrl, PyUrl}; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; use super::return_enums::{EitherBytes, EitherInt, EitherString}; -use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonInput}; +use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, ValidationMatch}; #[derive(Debug, Clone, Copy)] pub enum InputType { @@ -46,9 +48,7 @@ impl TryFrom<&str> for InputType { /// the convention is to either implement: /// * `strict_*` & `lax_*` if they have different behavior /// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same -pub trait Input<'a>: fmt::Debug + ToPyObject { - fn as_loc_item(&self) -> LocItem; - +pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { fn as_error_value(&'a self) -> InputValue<'a>; fn identity(&self) -> Option { @@ -89,87 +89,39 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>>; - fn parse_json(&'a self) -> ValResult<'a, JsonInput>; + fn parse_json(&'a self) -> ValResult<'a, JsonValue>; - fn validate_str(&'a self, strict: bool, coerce_numbers_to_str: bool) -> ValResult> { - if strict { - self.strict_str() - } else { - self.lax_str(coerce_numbers_to_str) - } - } - fn strict_str(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_str(&'a self, _coerce_numbers_to_str: bool) -> ValResult> { - self.strict_str() - } + fn validate_str( + &'a self, + strict: bool, + coerce_numbers_to_str: bool, + ) -> ValResult>>; - fn validate_bytes(&'a self, strict: bool) -> ValResult> { - if strict { - self.strict_bytes() - } else { - self.lax_bytes() - } - } - fn strict_bytes(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_bytes(&'a self) -> ValResult> { - self.strict_bytes() - } + fn validate_bytes(&'a self, strict: bool) -> ValResult>>; - fn validate_bool(&self, strict: bool) -> ValResult { - if strict { - self.strict_bool() - } else { - self.lax_bool() - } - } - fn strict_bool(&self) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_bool(&self) -> ValResult { - self.strict_bool() - } + fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch>; - fn validate_int(&'a self, strict: bool) -> ValResult> { - if strict { - self.strict_int() - } else { - self.lax_int() - } - } - fn strict_int(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_int(&'a self) -> ValResult> { - self.strict_int() - } + fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>>; - /// Extract an EitherInt from the input, only allowing exact - /// matches for an Int (no subclasses) fn exact_int(&'a self) -> ValResult> { - self.strict_int() + self.validate_int(true).and_then(|val_match| { + val_match + .require_exact() + .ok_or_else(|| ValError::new(ErrorTypeDefaults::IntType, self)) + }) } /// Extract a String from the input, only allowing exact /// matches for a String (no subclasses) fn exact_str(&'a self) -> ValResult> { - self.strict_str() + self.validate_str(true, false).and_then(|val_match| { + val_match + .require_exact() + .ok_or_else(|| ValError::new(ErrorTypeDefaults::StringType, self)) + }) } - fn validate_float(&'a self, strict: bool, ultra_strict: bool) -> ValResult> { - if ultra_strict { - self.ultra_strict_float() - } else if strict { - self.strict_float() - } else { - self.lax_float() - } - } - fn ultra_strict_float(&'a self) -> ValResult>; - fn strict_float(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_float(&'a self) -> ValResult> { - self.strict_float() - } + fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>>; fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> { if strict { @@ -257,87 +209,25 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn validate_iter(&self) -> ValResult; - fn validate_date(&self, strict: bool) -> ValResult { - if strict { - self.strict_date() - } else { - self.lax_date() - } - } - fn strict_date(&self) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_date(&self) -> ValResult { - self.strict_date() - } + fn validate_date(&self, strict: bool) -> ValResult>; fn validate_time( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if strict { - self.strict_time(microseconds_overflow_behavior) - } else { - self.lax_time(microseconds_overflow_behavior) - } - } - fn strict_time( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_time( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - self.strict_time(microseconds_overflow_behavior) - } + ) -> ValResult>; fn validate_datetime( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if strict { - self.strict_datetime(microseconds_overflow_behavior) - } else { - self.lax_datetime(microseconds_overflow_behavior) - } - } - fn strict_datetime( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_datetime( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - self.strict_datetime(microseconds_overflow_behavior) - } + ) -> ValResult>; fn validate_timedelta( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if strict { - self.strict_timedelta(microseconds_overflow_behavior) - } else { - self.lax_timedelta(microseconds_overflow_behavior) - } - } - fn strict_timedelta( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_timedelta( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - self.strict_timedelta(microseconds_overflow_behavior) - } + ) -> ValResult>; } /// The problem to solve here is that iterating a `StringMapping` returns an owned diff --git a/src/input/input_json.rs b/src/input/input_json.rs index e375f5755..0411a25d6 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -1,46 +1,49 @@ use std::borrow::Cow; +use jiter::{JsonArray, JsonValue}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; -use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::validators::decimal::create_decimal; use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int, string_to_vec}; +use super::return_enums::ValidationMatch; +use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, - GenericIterator, GenericMapping, Input, JsonArgs, JsonArray, JsonInput, + GenericIterator, GenericMapping, Input, JsonArgs, }; -impl<'a> Input<'a> for JsonInput { - /// This is required by since JSON object keys are always strings, I don't think it can be called - #[cfg_attr(has_coverage_attribute, coverage(off))] +/// This is required but since JSON object keys are always strings, I don't think it can be called +impl AsLocItem for JsonValue { fn as_loc_item(&self) -> LocItem { match self { - JsonInput::Int(i) => (*i).into(), - JsonInput::String(s) => s.as_str().into(), + JsonValue::Int(i) => (*i).into(), + JsonValue::Str(s) => s.as_str().into(), v => format!("{v:?}").into(), } } +} +impl<'a> Input<'a> for JsonValue { fn as_error_value(&'a self) -> InputValue<'a> { - // cloning JsonInput is cheap due to use of Arc + // cloning JsonValue is cheap due to use of Arc InputValue::JsonInput(self.clone()) } fn is_none(&self) -> bool { - matches!(self, JsonInput::Null) + matches!(self, JsonValue::Null) } fn as_kwargs(&'a self, py: Python<'a>) -> Option<&'a PyDict> { match self { - JsonInput::Object(object) => { + JsonValue::Object(object) => { let dict = PyDict::new(py); for (k, v) in object.iter() { dict.set_item(k, v.to_object(py)).unwrap(); @@ -53,15 +56,15 @@ impl<'a> Input<'a> for JsonInput { fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { match self { - JsonInput::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), - JsonInput::Array(array) => Ok(JsonArgs::new(Some(array), None).into()), + JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), + JsonValue::Array(array) => Ok(JsonArgs::new(Some(array), None).into()), _ => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)), } } fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> { match self { - JsonInput::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), + JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), _ => { let class_name = class_name.to_string(); Err(ValError::new( @@ -75,118 +78,88 @@ impl<'a> Input<'a> for JsonInput { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { match self { - JsonInput::String(s) => serde_json::from_str(s.as_str()).map_err(|e| map_json_err(self, e)), + JsonValue::Str(s) => JsonValue::parse(s.as_bytes(), true).map_err(|e| map_json_err(self, e)), _ => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), } } - fn strict_str(&'a self) -> ValResult> { + fn exact_str(&'a self) -> ValResult> { match self { - JsonInput::String(s) => Ok(s.as_str().into()), + JsonValue::Str(s) => Ok(s.as_str().into()), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } - fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { + + fn validate_str( + &'a self, + strict: bool, + coerce_numbers_to_str: bool, + ) -> ValResult>> { + // Justification for `strict` instead of `exact` is that in JSON strings can also + // represent other datatypes such as UUID and date more exactly, so string is a + // converting input + // TODO: in V3 we may want to make JSON str always win if in union, for consistency, + // see https://github.com/pydantic/pydantic-core/pull/867#discussion_r1386582501 match self { - JsonInput::String(s) => Ok(s.as_str().into()), - JsonInput::BigInt(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Float(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Int(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Uint(v) if coerce_numbers_to_str => Ok(v.to_string().into()), + JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_str().into())), + JsonValue::Int(i) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(i.to_string().into())), + JsonValue::BigInt(b) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(b.to_string().into())), + JsonValue::Float(f) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(f.to_string().into())), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } - fn validate_bytes(&'a self, _strict: bool) -> ValResult> { + fn validate_bytes(&'a self, _strict: bool) -> ValResult>> { match self { - JsonInput::String(s) => Ok(s.as_bytes().into()), + JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_bytes().into())), _ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), } } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn strict_bytes(&'a self) -> ValResult> { - self.validate_bytes(false) - } - fn strict_bool(&self) -> ValResult { - match self { - JsonInput::Bool(b) => Ok(*b), - _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), - } - } - fn lax_bool(&self) -> ValResult { + fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch> { match self { - JsonInput::Bool(b) => Ok(*b), - JsonInput::String(s) => str_as_bool(self, s), - JsonInput::Int(int) => int_as_bool(self, *int), - JsonInput::Float(float) => match float_as_int(self, *float) { + JsonValue::Bool(b) => Ok(ValidationMatch::exact(*b)), + JsonValue::Str(s) if !strict => str_as_bool(self, s).map(ValidationMatch::lax), + JsonValue::Int(int) if !strict => int_as_bool(self, *int).map(ValidationMatch::lax), + JsonValue::Float(float) if !strict => match float_as_int(self, *float) { Ok(int) => int .as_bool() - .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)), + .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)) + .map(ValidationMatch::lax), _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), }, _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), } } - fn strict_int(&'a self) -> ValResult> { + fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { match self { - JsonInput::Int(i) => Ok(EitherInt::I64(*i)), - JsonInput::Uint(u) => Ok(EitherInt::U64(*u)), - JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), - _ => Err(ValError::new(ErrorTypeDefaults::IntType, self)), - } - } - fn lax_int(&'a self) -> ValResult> { - match self { - JsonInput::Bool(b) => match *b { - true => Ok(EitherInt::I64(1)), - false => Ok(EitherInt::I64(0)), - }, - JsonInput::Int(i) => Ok(EitherInt::I64(*i)), - JsonInput::Uint(u) => Ok(EitherInt::U64(*u)), - JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), - JsonInput::Float(f) => float_as_int(self, *f), - JsonInput::String(str) => str_as_int(self, str), + JsonValue::Int(i) => Ok(ValidationMatch::exact(EitherInt::I64(*i))), + JsonValue::BigInt(b) => Ok(ValidationMatch::exact(EitherInt::BigInt(b.clone()))), + JsonValue::Bool(b) if !strict => Ok(ValidationMatch::lax(EitherInt::I64((*b).into()))), + JsonValue::Float(f) if !strict => float_as_int(self, *f).map(ValidationMatch::lax), + JsonValue::Str(str) if !strict => str_as_int(self, str).map(ValidationMatch::lax), _ => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } - fn ultra_strict_float(&'a self) -> ValResult> { + fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { match self { - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), - _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), - } - } - fn strict_float(&'a self) -> ValResult> { - match self { - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), - JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)), - JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)), - _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), - } - } - fn lax_float(&'a self) -> ValResult> { - match self { - JsonInput::Bool(b) => match *b { - true => Ok(EitherFloat::F64(1.0)), - false => Ok(EitherFloat::F64(0.0)), - }, - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), - JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)), - JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)), - JsonInput::String(str) => str_as_float(self, str), + JsonValue::Float(f) => Ok(ValidationMatch::exact(EitherFloat::F64(*f))), + JsonValue::Int(i) => Ok(ValidationMatch::strict(EitherFloat::F64(*i as f64))), + JsonValue::Bool(b) if !strict => Ok(ValidationMatch::lax(EitherFloat::F64(if *b { 1.0 } else { 0.0 }))), + JsonValue::Str(str) if !strict => str_as_float(self, str).map(ValidationMatch::lax), _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { match self { - JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py), + JsonValue::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py), - JsonInput::String(..) | JsonInput::Int(..) | JsonInput::Uint(..) | JsonInput::BigInt(..) => { + JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => { create_decimal(self.to_object(py).into_ref(py), self, py) } _ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)), @@ -195,7 +168,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_dict(&'a self, _strict: bool) -> ValResult> { match self { - JsonInput::Object(dict) => Ok(dict.into()), + JsonValue::Object(dict) => Ok(dict.into()), _ => Err(ValError::new(ErrorTypeDefaults::DictType, self)), } } @@ -206,7 +179,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_list(&'a self, _strict: bool) -> ValResult> { match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::ListType, self)), } } @@ -218,7 +191,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_tuple(&'a self, _strict: bool) -> ValResult> { // just as in set's case, List has to be allowed match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::TupleType, self)), } } @@ -230,7 +203,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_set(&'a self, _strict: bool) -> ValResult> { // we allow a list here since otherwise it would be impossible to create a set from JSON match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::SetType, self)), } } @@ -242,7 +215,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_frozenset(&'a self, _strict: bool) -> ValResult> { // we allow a list here since otherwise it would be impossible to create a frozenset from JSON match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::FrozenSetType, self)), } } @@ -253,54 +226,44 @@ impl<'a> Input<'a> for JsonInput { fn extract_generic_iterable(&self) -> ValResult { match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), - JsonInput::String(s) => Ok(GenericIterable::JsonString(s)), - JsonInput::Object(object) => Ok(GenericIterable::JsonObject(object)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Str(s) => Ok(GenericIterable::JsonString(s)), + JsonValue::Object(object) => Ok(GenericIterable::JsonObject(object)), _ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)), } } fn validate_iter(&self) -> ValResult { match self { - JsonInput::Array(a) => Ok(a.clone().into()), - JsonInput::String(s) => Ok(string_to_vec(s).into()), - JsonInput::Object(object) => { + JsonValue::Array(a) => Ok(a.clone().into()), + JsonValue::Str(s) => Ok(string_to_vec(s).into()), + JsonValue::Object(object) => { // return keys iterator to match python's behavior - let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonInput::String(k.clone())).collect()); + let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonValue::Str(k.clone())).collect()); Ok(keys.into()) } _ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)), } } - fn validate_date(&self, _strict: bool) -> ValResult { + fn validate_date(&self, _strict: bool) -> ValResult> { match self { - JsonInput::String(v) => bytes_as_date(self, v.as_bytes()), + JsonValue::Str(v) => bytes_as_date(self, v.as_bytes()).map(ValidationMatch::strict), _ => Err(ValError::new(ErrorTypeDefaults::DateType, self)), } } - // NO custom `lax_date` implementation, if strict_date fails, the validator will fallback to lax_datetime - // then check there's no remainder - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn strict_date(&self) -> ValResult { - self.validate_date(false) - } - - fn strict_time( + fn validate_time( &self, + strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { match self { - JsonInput::String(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), - _ => Err(ValError::new(ErrorTypeDefaults::TimeType, self)), - } - } - fn lax_time(&self, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior) -> ValResult { - match self { - JsonInput::String(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => int_as_time(self, *v, 0), - JsonInput::Float(v) => float_as_time(self, *v), - JsonInput::BigInt(_) => Err(ValError::new( + JsonValue::Str(v) => { + bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::strict) + } + JsonValue::Int(v) if !strict => int_as_time(self, *v, 0).map(ValidationMatch::lax), + JsonValue::Float(v) if !strict => float_as_time(self, *v).map(ValidationMatch::lax), + JsonValue::BigInt(_) if !strict => Err(ValError::new( ErrorType::TimeParsing { error: Cow::Borrowed( speedate::ParseError::TimeTooLarge @@ -315,64 +278,56 @@ impl<'a> Input<'a> for JsonInput { } } - fn strict_datetime( - &self, - microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - match self { - JsonInput::String(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), - _ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), - } - } - fn lax_datetime( + fn validate_datetime( &self, - microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + strict: bool, + microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult> { match self { - JsonInput::String(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => int_as_datetime(self, *v, 0), - JsonInput::Float(v) => float_as_datetime(self, *v), + JsonValue::Str(v) => { + bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::strict) + } + JsonValue::Int(v) if !strict => int_as_datetime(self, *v, 0).map(ValidationMatch::lax), + JsonValue::Float(v) if !strict => float_as_datetime(self, *v).map(ValidationMatch::lax), _ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), } } - fn strict_timedelta( + fn validate_timedelta( &self, - microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + strict: bool, + microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult> { match self { - JsonInput::String(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), - _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), - } - } - fn lax_timedelta( - &self, - microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - match self { - JsonInput::String(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => Ok(int_as_duration(self, *v)?.into()), - JsonInput::Float(v) => Ok(float_as_duration(self, *v)?.into()), + JsonValue::Str(v) => { + bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::strict) + } + JsonValue::Int(v) if !strict => { + int_as_duration(self, *v).map(|duration| ValidationMatch::lax(duration.into())) + } + JsonValue::Float(v) if !strict => { + float_as_duration(self, *v).map(|duration| ValidationMatch::lax(duration.into())) + } _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } } -impl BorrowInput for &'_ JsonInput { - type Input<'a> = JsonInput where Self: 'a; +impl BorrowInput for &'_ JsonValue { + type Input<'a> = JsonValue where Self: 'a; fn borrow_input(&self) -> &Self::Input<'_> { self } } -/// TODO: it would be good to get JsonInput and StringMapping string variants to go through this -/// implementation -/// Required for Dict keys so the string can behave like an Input -impl<'a> Input<'a> for String { +impl AsLocItem for String { fn as_loc_item(&self) -> LocItem { self.to_string().into() } +} +/// Required for JSON Object keys so the string can behave like an Input +impl<'a> Input<'a> for String { fn as_error_value(&'a self) -> InputValue<'a> { InputValue::String(self) } @@ -398,34 +353,40 @@ impl<'a> Input<'a> for String { )) } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { - serde_json::from_str(self.as_str()).map_err(|e| map_json_err(self, e)) + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + JsonValue::parse(self.as_bytes(), true).map_err(|e| map_json_err(self, e)) } - fn strict_str(&'a self) -> ValResult> { - Ok(self.as_str().into()) + fn validate_str( + &'a self, + _strict: bool, + _coerce_numbers_to_str: bool, + ) -> ValResult>> { + // Justification for `strict` instead of `exact` is that in JSON strings can also + // represent other datatypes such as UUID and date more exactly, so string is a + // converting input + // TODO: in V3 we may want to make JSON str always win if in union, for consistency, + // see https://github.com/pydantic/pydantic-core/pull/867#discussion_r1386582501 + Ok(ValidationMatch::strict(self.as_str().into())) } - fn strict_bytes(&'a self) -> ValResult> { - Ok(self.as_bytes().into()) + fn validate_bytes(&'a self, _strict: bool) -> ValResult>> { + Ok(ValidationMatch::strict(self.as_bytes().into())) } - fn strict_bool(&self) -> ValResult { - str_as_bool(self, self) + fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch> { + str_as_bool(self, self).map(ValidationMatch::lax) } - fn strict_int(&'a self) -> ValResult> { + fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { match self.parse() { - Ok(i) => Ok(EitherInt::I64(i)), + Ok(i) => Ok(ValidationMatch::lax(EitherInt::I64(i))), Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), } } - fn ultra_strict_float(&'a self) -> ValResult> { - self.strict_float() - } - fn strict_float(&'a self) -> ValResult> { - str_as_float(self, self) + fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { + str_as_float(self, self).map(ValidationMatch::lax) } fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { @@ -465,29 +426,32 @@ impl<'a> Input<'a> for String { Ok(string_to_vec(self).into()) } - fn strict_date(&self) -> ValResult { - bytes_as_date(self, self.as_bytes()) + fn validate_date(&self, _strict: bool) -> ValResult> { + bytes_as_date(self, self.as_bytes()).map(ValidationMatch::lax) } - fn strict_time( + fn validate_time( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - bytes_as_time(self, self.as_bytes(), microseconds_overflow_behavior) + ) -> ValResult> { + bytes_as_time(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } - fn strict_datetime( + fn validate_datetime( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior) + ) -> ValResult> { + bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } - fn strict_timedelta( + fn validate_timedelta( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior) + ) -> ValResult> { + bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } } @@ -504,3 +468,7 @@ impl BorrowInput for String { self } } + +fn string_to_vec(s: &str) -> JsonArray { + JsonArray::new(s.chars().map(|c| JsonValue::Str(c.to_string())).collect()) +} diff --git a/src/input/input_python.rs b/src/input/input_python.rs index cf84c5517..90d2c2a8b 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -8,12 +8,15 @@ use pyo3::types::{ }; #[cfg(not(PyPy))] use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; -use pyo3::{intern, AsPyPointer, PyTypeInfo}; +use pyo3::{intern, PyTypeInfo}; + +use jiter::JsonValue; use speedate::MicrosecondsPrecisionOverflowBehavior; -use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; use crate::validators::decimal::{create_decimal, get_decimal_type}; +use crate::validators::Exactness; use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl}; use super::datetime::{ @@ -21,10 +24,14 @@ use super::datetime::{ float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; +use super::return_enums::ValidationMatch; +use super::shared::{ + decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float, + str_as_int, +}; use super::{ py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, - GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs, + GenericIterable, GenericIterator, GenericMapping, Input, PyArgs, }; #[cfg(not(PyPy))] @@ -32,7 +39,7 @@ macro_rules! extract_dict_keys { ($py:expr, $obj:ident) => { $obj.downcast::() .ok() - .map(|v| PyIterator::from_object($py, v).unwrap()) + .map(|v| PyIterator::from_object(v).unwrap()) }; } @@ -40,7 +47,7 @@ macro_rules! extract_dict_keys { macro_rules! extract_dict_keys { ($py:expr, $obj:ident) => { if is_dict_keys_type($obj) { - Some(PyIterator::from_object($py, $obj).unwrap()) + Some(PyIterator::from_object($obj).unwrap()) } else { None } @@ -52,7 +59,7 @@ macro_rules! extract_dict_values { ($py:expr, $obj:ident) => { $obj.downcast::() .ok() - .map(|v| PyIterator::from_object($py, v).unwrap()) + .map(|v| PyIterator::from_object(v).unwrap()) }; } @@ -60,7 +67,7 @@ macro_rules! extract_dict_values { macro_rules! extract_dict_values { ($py:expr, $obj:ident) => { if is_dict_values_type($obj) { - Some(PyIterator::from_object($py, $obj).unwrap()) + Some(PyIterator::from_object($obj).unwrap()) } else { None } @@ -72,7 +79,7 @@ macro_rules! extract_dict_items { ($py:expr, $obj:ident) => { $obj.downcast::() .ok() - .map(|v| PyIterator::from_object($py, v).unwrap()) + .map(|v| PyIterator::from_object(v).unwrap()) }; } @@ -80,14 +87,14 @@ macro_rules! extract_dict_items { macro_rules! extract_dict_items { ($py:expr, $obj:ident) => { if is_dict_items_type($obj) { - Some(PyIterator::from_object($py, $obj).unwrap()) + Some(PyIterator::from_object($obj).unwrap()) } else { None } }; } -impl<'a> Input<'a> for PyAny { +impl AsLocItem for PyAny { fn as_loc_item(&self) -> LocItem { if let Ok(py_str) = self.downcast::() { py_str.to_string_lossy().as_ref().into() @@ -97,7 +104,9 @@ impl<'a> Input<'a> for PyAny { safe_repr(self).to_string().into() } } +} +impl<'a> Input<'a> for PyAny { fn as_error_value(&'a self) -> InputValue<'a> { InputValue::PyAny(self) } @@ -180,35 +189,75 @@ impl<'a> Input<'a> for PyAny { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { - if let Ok(py_bytes) = self.downcast::() { - serde_json::from_slice(py_bytes.as_bytes()).map_err(|e| map_json_err(self, e)) + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + let bytes = if let Ok(py_bytes) = self.downcast::() { + py_bytes.as_bytes() } else if let Ok(py_str) = self.downcast::() { let str = py_string_str(py_str)?; - serde_json::from_str(str).map_err(|e| map_json_err(self, e)) + str.as_bytes() } else if let Ok(py_byte_array) = self.downcast::() { // Safety: from_slice does not run arbitrary Python code and the GIL is held so the - // bytes array will not be mutated while from_slice is reading it - serde_json::from_slice(unsafe { py_byte_array.as_bytes() }).map_err(|e| map_json_err(self, e)) + // bytes array will not be mutated while `JsonValue::parse` is reading it + unsafe { py_byte_array.as_bytes() } } else { - Err(ValError::new(ErrorTypeDefaults::JsonType, self)) - } + return Err(ValError::new(ErrorTypeDefaults::JsonType, self)); + }; + JsonValue::parse(bytes, true).map_err(|e| map_json_err(self, e)) } - fn strict_str(&'a self) -> ValResult> { - if let Ok(py_str) = PyString::try_from_exact(self) { - Ok(py_str.into()) + fn validate_str( + &'a self, + strict: bool, + coerce_numbers_to_str: bool, + ) -> ValResult>> { + if let Ok(py_str) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(py_str.into())); } else if let Ok(py_str) = self.downcast::() { // force to a rust string to make sure behavior is consistent whether or not we go via a // rust string in StrConstrainedValidator - e.g. to_lower - Ok(py_string_str(py_str)?.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::StringType, self)) + return Ok(ValidationMatch::strict(py_string_str(py_str)?.into())); + } + + 'lax: { + if !strict { + return if let Ok(bytes) = self.downcast::() { + match from_utf8(bytes.as_bytes()) { + Ok(str) => Ok(str.into()), + Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), + } + } else if let Ok(py_byte_array) = self.downcast::() { + // Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated, + // and we immediately copy the bytes into a new Python string + match from_utf8(unsafe { py_byte_array.as_bytes() }) { + // Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the + // final output needs to be Python anyway. + Ok(s) => Ok(PyString::new(self.py(), s).into()), + Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), + } + } else if coerce_numbers_to_str && !PyBool::is_exact_type_of(self) && { + let py = self.py(); + let decimal_type: Py = get_decimal_type(py); + + // only allow int, float, and decimal (not bool) + self.is_instance_of::() + || self.is_instance_of::() + || self.is_instance(decimal_type.as_ref(py)).unwrap_or_default() + } { + Ok(self.str()?.into()) + } else if let Some(enum_val) = maybe_as_enum(self) { + Ok(enum_val.str()?.into()) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } + + Err(ValError::new(ErrorTypeDefaults::StringType, self)) } fn exact_str(&'a self) -> ValResult> { - if let Ok(py_str) = PyString::try_from_exact(self) { + if let Ok(py_str) = ::try_from_exact(self) { Ok(EitherString::Py(py_str)) } else { Err(ValError::new(ErrorTypeDefaults::IntType, self)) @@ -223,161 +272,118 @@ impl<'a> Input<'a> for PyAny { } } - fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { - if let Ok(py_str) = ::try_from_exact(self) { - Ok(py_str.into()) - } else if let Ok(py_str) = self.downcast::() { - // force to a rust string to make sure behaviour is consistent whether or not we go via a - // rust string in StrConstrainedValidator - e.g. to_lower - Ok(py_string_str(py_str)?.into()) - } else if let Ok(bytes) = self.downcast::() { - let str = match from_utf8(bytes.as_bytes()) { - Ok(s) => s, - Err(_) => return Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), - }; - Ok(str.into()) - } else if let Ok(py_byte_array) = self.downcast::() { - // Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated, - // and we immediately copy the bytes into a new Python string - let s = match from_utf8(unsafe { py_byte_array.as_bytes() }) { - // Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the - // final output needs to be Python anyway. - Ok(s) => PyString::new(self.py(), s), - Err(_) => return Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), - }; - Ok(s.into()) - } else if coerce_numbers_to_str && { - let py = self.py(); - let decimal_type: Py = get_decimal_type(py); - - self.is_instance_of::() - || self.is_instance_of::() - || self.is_instance(decimal_type.as_ref(py)).unwrap_or_default() - } { - Ok(self.str()?.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::StringType, self)) - } - } - - fn strict_bytes(&'a self) -> ValResult> { - if let Ok(py_bytes) = self.downcast::() { - Ok(py_bytes.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::BytesType, self)) + fn validate_bytes(&'a self, strict: bool) -> ValResult>> { + if let Ok(py_bytes) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(py_bytes.into())); + } else if let Ok(py_bytes) = self.downcast::() { + return Ok(ValidationMatch::strict(py_bytes.into())); } - } - fn lax_bytes(&'a self) -> ValResult> { - if let Ok(py_bytes) = self.downcast::() { - Ok(py_bytes.into()) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - Ok(str.as_bytes().into()) - } else if let Ok(py_byte_array) = self.downcast::() { - Ok(py_byte_array.to_vec().into()) - } else { - Err(ValError::new(ErrorTypeDefaults::BytesType, self)) + 'lax: { + if !strict { + return if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + Ok(str.as_bytes().into()) + } else if let Ok(py_byte_array) = self.downcast::() { + Ok(py_byte_array.to_vec().into()) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } - } - fn strict_bool(&self) -> ValResult { - if let Ok(bool) = self.downcast::() { - Ok(bool.is_true()) - } else { - Err(ValError::new(ErrorTypeDefaults::BoolType, self)) - } + Err(ValError::new(ErrorTypeDefaults::BytesType, self)) } - fn lax_bool(&self) -> ValResult { + fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch> { if let Ok(bool) = self.downcast::() { - Ok(bool.is_true()) - } else if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? { - str_as_bool(self, &cow_str) - } else if let Ok(int) = extract_i64(self) { - int_as_bool(self, int) - } else if let Ok(float) = self.extract::() { - match float_as_int(self, float) { - Ok(int) => int - .as_bool() - .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)), - _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), + return Ok(ValidationMatch::exact(bool.is_true())); + } + + if !strict { + if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? { + return str_as_bool(self, &cow_str).map(ValidationMatch::lax); + } else if let Ok(int) = extract_i64(self) { + return int_as_bool(self, int).map(ValidationMatch::lax); + } else if let Ok(float) = self.extract::() { + if let Ok(int) = float_as_int(self, float) { + return int + .as_bool() + .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)) + .map(ValidationMatch::lax); + }; } - } else { - Err(ValError::new(ErrorTypeDefaults::BoolType, self)) } + + Err(ValError::new(ErrorTypeDefaults::BoolType, self)) } - fn strict_int(&'a self) -> ValResult> { - if PyInt::is_exact_type_of(self) { - Ok(EitherInt::Py(self)) - } else if PyInt::is_type_of(self) { + fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { + if self.is_exact_instance_of::() { + return Ok(ValidationMatch::exact(EitherInt::Py(self))); + } else if self.is_instance_of::() { // bools are a subclass of int, so check for bool type in this specific case - if PyBool::is_exact_type_of(self) { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) + let exactness = if self.is_instance_of::() { + if strict { + return Err(ValError::new(ErrorTypeDefaults::IntType, self)); + } + Exactness::Lax } else { - // force to an int to upcast to a pure python int - EitherInt::upcast(self) + Exactness::Strict + }; + + // force to an int to upcast to a pure python int + return EitherInt::upcast(self).map(|either_int| ValidationMatch::new(either_int, exactness)); + } + + 'lax: { + if !strict { + return if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? { + str_as_int(self, &cow_str) + } else if self.is_exact_instance_of::() { + float_as_int(self, self.extract::()?) + } else if let Ok(decimal) = self.strict_decimal(self.py()) { + decimal_as_int(self.py(), self, decimal) + } else if let Ok(float) = self.extract::() { + float_as_int(self, float) + } else if let Some(enum_val) = maybe_as_enum(self) { + Ok(EitherInt::Py(enum_val)) + } else { + break 'lax; + } + .map(ValidationMatch::lax); } - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) } - } - fn lax_int(&'a self) -> ValResult> { - if PyInt::is_exact_type_of(self) { - Ok(EitherInt::Py(self)) - } else if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? { - // Try strings before subclasses of int as that will be far more common - str_as_int(self, &cow_str) - } else if PyInt::is_type_of(self) { - // force to an int to upcast to a pure python int to maintain current behaviour - EitherInt::upcast(self) - } else if PyFloat::is_exact_type_of(self) { - float_as_int(self, self.extract::()?) - } else if let Ok(decimal) = self.strict_decimal(self.py()) { - decimal_as_int(self.py(), self, decimal) - } else if let Ok(float) = self.extract::() { - float_as_int(self, float) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } + Err(ValError::new(ErrorTypeDefaults::IntType, self)) } - fn ultra_strict_float(&'a self) -> ValResult> { - if self.is_instance_of::() { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) - } else if let Ok(float) = self.downcast::() { - Ok(EitherFloat::Py(float)) - } else { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) + fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { + if let Ok(float) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(EitherFloat::Py(float))); } - } - fn strict_float(&'a self) -> ValResult> { - if let Ok(py_float) = self.downcast_exact::() { - Ok(EitherFloat::Py(py_float)) - } else if let Ok(float) = self.extract::() { - // bools are cast to floats as either 0.0 or 1.0, so check for bool type in this specific case - if (float == 0.0 || float == 1.0) && PyBool::is_exact_type_of(self) { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) - } else { - Ok(EitherFloat::F64(float)) + + if !strict { + if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? { + // checking for bytes and string is fast, so do this before isinstance(float) + return str_as_float(self, &cow_str).map(ValidationMatch::lax); } - } else { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) } - } - fn lax_float(&'a self) -> ValResult> { - if let Ok(py_float) = self.downcast_exact() { - Ok(EitherFloat::Py(py_float)) - } else if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? { - str_as_float(self, &cow_str) - } else if let Ok(float) = self.extract::() { - Ok(EitherFloat::F64(float)) - } else { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) + if let Ok(float) = self.extract::() { + let exactness = if self.is_instance_of::() { + if strict { + return Err(ValError::new(ErrorTypeDefaults::FloatType, self)); + } + Exactness::Lax + } else { + Exactness::Strict + }; + return Ok(ValidationMatch::new(EitherFloat::F64(float), exactness)); } + + Err(ValError::new(ErrorTypeDefaults::FloatType, self)) } fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { @@ -594,128 +600,136 @@ impl<'a> Input<'a> for PyAny { } } - fn strict_date(&self) -> ValResult { - if PyDateTime::is_type_of(self) { - // have to check if it's a datetime first, otherwise the line below converts to a date - Err(ValError::new(ErrorTypeDefaults::DateType, self)) - } else if let Ok(date) = self.downcast::() { - Ok(date.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::DateType, self)) - } - } - - fn lax_date(&self) -> ValResult { - if PyDateTime::is_type_of(self) { + fn validate_date(&self, strict: bool) -> ValResult> { + if let Ok(date) = self.downcast_exact::() { + Ok(ValidationMatch::exact(date.into())) + } else if PyDateTime::is_type_of(self) { // have to check if it's a datetime first, otherwise the line below converts to a date // even if we later try coercion from a datetime, we don't want to return a datetime now Err(ValError::new(ErrorTypeDefaults::DateType, self)) } else if let Ok(date) = self.downcast::() { - Ok(date.into()) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - bytes_as_date(self, str.as_bytes()) - } else if let Ok(py_bytes) = self.downcast::() { - bytes_as_date(self, py_bytes.as_bytes()) + Ok(ValidationMatch::strict(date.into())) + } else if let Some(bytes) = { + if strict { + None + } else if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + Some(str.as_bytes()) + } else if let Ok(py_bytes) = self.downcast::() { + Some(py_bytes.as_bytes()) + } else { + None + } + } { + bytes_as_date(self, bytes).map(ValidationMatch::lax) } else { Err(ValError::new(ErrorTypeDefaults::DateType, self)) } } - fn strict_time( + fn validate_time( &self, - _microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if let Ok(time) = self.downcast::() { - Ok(time.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::TimeType, self)) + strict: bool, + microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult> { + if let Ok(time) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(time.into())); + } else if let Ok(time) = self.downcast::() { + return Ok(ValidationMatch::strict(time.into())); + } + + 'lax: { + if !strict { + return if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + bytes_as_time(self, str.as_bytes(), microseconds_overflow_behavior) + } else if let Ok(py_bytes) = self.downcast::() { + bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior) + } else if PyBool::is_exact_type_of(self) { + Err(ValError::new(ErrorTypeDefaults::TimeType, self)) + } else if let Ok(int) = extract_i64(self) { + int_as_time(self, int, 0) + } else if let Ok(float) = self.extract::() { + float_as_time(self, float) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } - } - fn lax_time(&self, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior) -> ValResult { - if let Ok(time) = self.downcast::() { - Ok(time.into()) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - bytes_as_time(self, str.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(py_bytes) = self.downcast::() { - bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior) - } else if PyBool::is_exact_type_of(self) { - Err(ValError::new(ErrorTypeDefaults::TimeType, self)) - } else if let Ok(int) = extract_i64(self) { - int_as_time(self, int, 0) - } else if let Ok(float) = self.extract::() { - float_as_time(self, float) - } else { - Err(ValError::new(ErrorTypeDefaults::TimeType, self)) - } + Err(ValError::new(ErrorTypeDefaults::TimeType, self)) } - fn strict_datetime( - &self, - _microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if let Ok(dt) = self.downcast::() { - Ok(dt.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) - } - } - - fn lax_datetime( + fn validate_datetime( &self, + strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if let Ok(dt) = self.downcast::() { - Ok(dt.into()) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - bytes_as_datetime(self, str.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(py_bytes) = self.downcast::() { - bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior) - } else if PyBool::is_exact_type_of(self) { - Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) - } else if let Ok(int) = extract_i64(self) { - int_as_datetime(self, int, 0) - } else if let Ok(float) = self.extract::() { - float_as_datetime(self, float) - } else if let Ok(date) = self.downcast::() { - Ok(date_as_datetime(date)?) - } else { - Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) + ) -> ValResult> { + if let Ok(dt) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(dt.into())); + } else if let Ok(dt) = self.downcast::() { + return Ok(ValidationMatch::strict(dt.into())); + } + + 'lax: { + if !strict { + return if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + bytes_as_datetime(self, str.as_bytes(), microseconds_overflow_behavior) + } else if let Ok(py_bytes) = self.downcast::() { + bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior) + } else if PyBool::is_exact_type_of(self) { + Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) + } else if let Ok(int) = extract_i64(self) { + int_as_datetime(self, int, 0) + } else if let Ok(float) = self.extract::() { + float_as_datetime(self, float) + } else if let Ok(date) = self.downcast::() { + Ok(date_as_datetime(date)?) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } - } - fn strict_timedelta( - &self, - _microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if let Ok(either_dt) = EitherTimedelta::try_from(self) { - Ok(either_dt) - } else { - Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) - } + Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) } - fn lax_timedelta( + fn validate_timedelta( &self, + strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { if let Ok(either_dt) = EitherTimedelta::try_from(self) { - Ok(either_dt) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(py_bytes) = self.downcast::() { - bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(int) = extract_i64(self) { - Ok(int_as_duration(self, int)?.into()) - } else if let Ok(float) = self.extract::() { - Ok(float_as_duration(self, float)?.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) + let exactness = if matches!(either_dt, EitherTimedelta::PyExact(_)) { + Exactness::Exact + } else { + Exactness::Strict + }; + return Ok(ValidationMatch::new(either_dt, exactness)); + } + + 'lax: { + if !strict { + return if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior) + } else if let Ok(py_bytes) = self.downcast::() { + bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior) + } else if let Ok(int) = extract_i64(self) { + Ok(int_as_duration(self, int)?.into()) + } else if let Ok(float) = self.extract::() { + Ok(float_as_duration(self, float)?.into()) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } + + Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) } } @@ -758,6 +772,18 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult Option<&PyAny> { + let py = v.py(); + let enum_meta_object = get_enum_meta_object(py); + let meta_type = v.get_type().get_type(); + if meta_type.is(&enum_meta_object) { + v.getattr(intern!(py, "value")).ok() + } else { + None + } +} + #[cfg(PyPy)] static DICT_KEYS_TYPE: pyo3::once_cell::GILOnceCell> = pyo3::once_cell::GILOnceCell::new(); diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 72a32d897..e27ef6461 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -1,9 +1,10 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; +use jiter::JsonValue; use speedate::MicrosecondsPrecisionOverflowBehavior; -use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::input::py_string_str; use crate::tools::safe_repr; use crate::validators::decimal::create_decimal; @@ -14,7 +15,7 @@ use super::datetime::{ use super::shared::{map_json_err, str_as_bool, str_as_float}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, - GenericIterator, GenericMapping, Input, JsonInput, + GenericIterator, GenericMapping, Input, ValidationMatch, }; #[derive(Debug)] @@ -52,14 +53,16 @@ impl<'py> StringMapping<'py> { } } -impl<'a> Input<'a> for StringMapping<'a> { +impl AsLocItem for StringMapping<'_> { fn as_loc_item(&self) -> LocItem { match self { Self::String(s) => s.to_string_lossy().as_ref().into(), Self::Mapping(d) => safe_repr(d).to_string().into(), } } +} +impl<'a> Input<'a> for StringMapping<'a> { fn as_error_value(&'a self) -> InputValue<'a> { match self { Self::String(s) => s.as_error_value(), @@ -83,64 +86,54 @@ impl<'a> Input<'a> for StringMapping<'a> { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { match self { Self::String(s) => { let str = py_string_str(s)?; - serde_json::from_str(str).map_err(|e| map_json_err(self, e)) + JsonValue::parse(str.as_bytes(), true).map_err(|e| map_json_err(self, e)) } Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), } } - fn strict_str(&'a self) -> ValResult> { + fn validate_str( + &'a self, + _strict: bool, + _coerce_numbers_to_str: bool, + ) -> ValResult>> { match self { - Self::String(s) => Ok((*s).into()), + Self::String(s) => Ok(ValidationMatch::strict((*s).into())), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } - fn strict_bytes(&'a self) -> ValResult> { + fn validate_bytes(&'a self, _strict: bool) -> ValResult>> { match self { - Self::String(s) => py_string_str(s).map(|b| b.as_bytes().into()), + Self::String(s) => py_string_str(s).map(|b| ValidationMatch::strict(b.as_bytes().into())), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), } } - fn lax_bytes(&'a self) -> ValResult> { + fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch> { match self { - Self::String(s) => { - let str = py_string_str(s)?; - Ok(str.as_bytes().into()) - } - Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), - } - } - - fn strict_bool(&self) -> ValResult { - match self { - Self::String(s) => str_as_bool(self, py_string_str(s)?), + Self::String(s) => str_as_bool(self, py_string_str(s)?).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), } } - fn strict_int(&'a self) -> ValResult> { + fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { match self { Self::String(s) => match py_string_str(s)?.parse() { - Ok(i) => Ok(EitherInt::I64(i)), + Ok(i) => Ok(ValidationMatch::strict(EitherInt::I64(i))), Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), }, Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } - fn ultra_strict_float(&'a self) -> ValResult> { - self.strict_float() - } - - fn strict_float(&'a self) -> ValResult> { + fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { match self { - Self::String(s) => str_as_float(self, py_string_str(s)?), + Self::String(s) => str_as_float(self, py_string_str(s)?).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } @@ -183,39 +176,45 @@ impl<'a> Input<'a> for StringMapping<'a> { Err(ValError::new(ErrorTypeDefaults::IterableType, self)) } - fn strict_date(&self) -> ValResult { + fn validate_date(&self, _strict: bool) -> ValResult> { match self { - Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes()), + Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes()).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DateType, self)), } } - fn strict_time( + fn validate_time( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { match self { - Self::String(s) => bytes_as_time(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::String(s) => bytes_as_time(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior) + .map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeType, self)), } } - fn strict_datetime( + fn validate_datetime( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { match self { - Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior) + .map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), } } - fn strict_timedelta( + fn validate_timedelta( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { match self { - Self::String(s) => bytes_as_timedelta(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::String(s) => bytes_as_timedelta(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior) + .map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } diff --git a/src/input/mod.rs b/src/input/mod.rs index 22d774a8c..d7ca0a5bf 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -7,7 +7,6 @@ mod input_abstract; mod input_json; mod input_python; mod input_string; -mod parse_json; mod return_enums; mod shared; @@ -18,11 +17,10 @@ pub(crate) use datetime::{ }; pub(crate) use input_abstract::{BorrowInput, Input, InputType}; pub(crate) use input_string::StringMapping; -pub(crate) use parse_json::{JsonArray, JsonInput, JsonObject}; pub(crate) use return_enums::{ py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, - MappingGenericIterator, PyArgs, StringMappingGenericIterator, + MappingGenericIterator, PyArgs, StringMappingGenericIterator, ValidationMatch, }; // Defined here as it's not exported by pyo3 diff --git a/src/input/parse_json.rs b/src/input/parse_json.rs deleted file mode 100644 index 20a107669..000000000 --- a/src/input/parse_json.rs +++ /dev/null @@ -1,222 +0,0 @@ -use std::fmt; -use std::sync::Arc; - -use num_bigint::BigInt; -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList}; -use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor}; -use smallvec::SmallVec; - -use crate::lazy_index_map::LazyIndexMap; - -/// similar to serde `Value` but with int and float split -#[derive(Clone, Debug)] -pub enum JsonInput { - Null, - Bool(bool), - Int(i64), - BigInt(BigInt), - Uint(u64), - Float(f64), - String(String), - Array(JsonArray), - Object(JsonObject), -} -pub type JsonArray = Arc>; -pub type JsonObject = Arc>; - -impl ToPyObject for JsonInput { - fn to_object(&self, py: Python<'_>) -> PyObject { - match self { - Self::Null => py.None(), - Self::Bool(b) => b.into_py(py), - Self::Int(i) => i.into_py(py), - Self::BigInt(b) => b.to_object(py), - Self::Uint(i) => i.into_py(py), - Self::Float(f) => f.into_py(py), - Self::String(s) => s.into_py(py), - Self::Array(v) => PyList::new(py, v.iter().map(|v| v.to_object(py))).into_py(py), - Self::Object(o) => { - let dict = PyDict::new(py); - for (k, v) in o.iter() { - dict.set_item(k, v.to_object(py)).unwrap(); - } - dict.into_py(py) - } - } - } -} - -impl<'de> Deserialize<'de> for JsonInput { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct JsonVisitor; - - impl<'de> Visitor<'de> for JsonVisitor { - type Value = JsonInput; - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("any valid JSON value") - } - - fn visit_bool(self, value: bool) -> Result { - Ok(JsonInput::Bool(value)) - } - - fn visit_i64(self, value: i64) -> Result { - Ok(JsonInput::Int(value)) - } - - fn visit_u64(self, value: u64) -> Result { - match i64::try_from(value) { - Ok(i) => Ok(JsonInput::Int(i)), - Err(_) => Ok(JsonInput::Uint(value)), - } - } - - fn visit_f64(self, value: f64) -> Result { - Ok(JsonInput::Float(value)) - } - - fn visit_str(self, value: &str) -> Result - where - E: SerdeError, - { - Ok(JsonInput::String(value.to_string())) - } - - fn visit_string(self, value: String) -> Result { - Ok(JsonInput::String(value)) - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_none(self) -> Result { - unreachable!() - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_some(self, _: D) -> Result - where - D: serde::Deserializer<'de>, - { - unreachable!() - } - - fn visit_unit(self) -> Result { - Ok(JsonInput::Null) - } - - fn visit_seq(self, mut visitor: V) -> Result - where - V: SeqAccess<'de>, - { - let mut vec = SmallVec::new(); - - while let Some(elem) = visitor.next_element()? { - vec.push(elem); - } - - Ok(JsonInput::Array(JsonArray::new(vec))) - } - - fn visit_map(self, mut visitor: V) -> Result - where - V: MapAccess<'de>, - { - const SERDE_JSON_NUMBER: &str = "$serde_json::private::Number"; - match visitor.next_key_seed(KeyDeserializer)? { - Some(first_key) => { - let mut values = LazyIndexMap::new(); - let first_value = visitor.next_value()?; - - // serde_json will parse arbitrary precision numbers into a map - // structure with a "number" key and a String value - 'try_number: { - if first_key == SERDE_JSON_NUMBER { - // Just in case someone tries to actually store that key in a real map, - // keep parsing and continue as a map if so - - if let Some((key, value)) = visitor.next_entry::()? { - // Important to preserve order of the keys - values.insert(first_key, first_value); - values.insert(key, value); - break 'try_number; - } - - if let JsonInput::String(s) = &first_value { - // Normalize the string to either an int or float - let normalized = if s.chars().any(|c| c == '.' || c == 'E' || c == 'e') { - JsonInput::Float( - s.parse() - .map_err(|e| V::Error::custom(format!("expected a float: {e}")))?, - ) - } else if let Ok(i) = s.parse::() { - JsonInput::Int(i) - } else if let Ok(big) = s.parse::() { - JsonInput::BigInt(big) - } else { - // Failed to normalize, just throw it in the map and continue - values.insert(first_key, first_value); - break 'try_number; - }; - - return Ok(normalized); - }; - } else { - values.insert(first_key, first_value); - } - } - - while let Some((key, value)) = visitor.next_entry()? { - values.insert(key, value); - } - Ok(JsonInput::Object(Arc::new(values))) - } - None => Ok(JsonInput::Object(Arc::new(LazyIndexMap::new()))), - } - } - } - - deserializer.deserialize_any(JsonVisitor) - } -} - -struct KeyDeserializer; - -impl<'de> DeserializeSeed<'de> for KeyDeserializer { - type Value = String; - - fn deserialize(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_str(self) - } -} - -impl<'de> Visitor<'de> for KeyDeserializer { - type Value = String; - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string key") - } - - fn visit_str(self, s: &str) -> Result - where - E: serde::de::Error, - { - Ok(s.to_string()) - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_string(self, _: String) -> Result - where - E: serde::de::Error, - { - unreachable!() - } -} diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index c492f40f0..56c7098df 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -4,6 +4,7 @@ use std::ops::Rem; use std::slice::Iter as SliceIter; use std::str::FromStr; +use jiter::{JsonArray, JsonObject, JsonValue}; use num_bigint::BigInt; use pyo3::exceptions::PyTypeError; @@ -13,7 +14,7 @@ use pyo3::types::{ PyByteArray, PyBytes, PyDict, PyFloat, PyFrozenSet, PyIterator, PyList, PyMapping, PySequence, PySet, PyString, PyTuple, }; -use pyo3::{ffi, intern, AsPyPointer, PyNativeType}; +use pyo3::{ffi, intern, PyNativeType}; #[cfg(not(PyPy))] use pyo3::types::PyFunction; @@ -23,12 +24,44 @@ use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult}; use crate::tools::py_err; -use crate::validators::{CombinedValidator, ValidationState, Validator}; +use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator}; use super::input_string::StringMapping; -use super::parse_json::{JsonArray, JsonInput, JsonObject}; use super::{py_error_on_minusone, Input}; +pub struct ValidationMatch(T, Exactness); + +impl ValidationMatch { + pub const fn new(value: T, exactness: Exactness) -> Self { + Self(value, exactness) + } + + pub const fn exact(value: T) -> Self { + Self(value, Exactness::Exact) + } + + pub const fn strict(value: T) -> Self { + Self(value, Exactness::Strict) + } + + pub const fn lax(value: T) -> Self { + Self(value, Exactness::Lax) + } + + pub fn require_exact(self) -> Option { + (self.1 == Exactness::Exact).then_some(self.0) + } + + pub fn unpack(self, state: &mut ValidationState) -> T { + state.floor_exactness(self.1); + self.0 + } + + pub fn into_inner(self) -> T { + self.0 + } +} + /// Container for all the collections (sized iterable containers) types, which /// can mostly be converted to each other in lax mode. /// This mostly matches python's definition of `Collection`. @@ -50,7 +83,7 @@ pub enum GenericIterable<'a> { PyByteArray(&'a PyByteArray), Sequence(&'a PySequence), Iterator(&'a PyIterator), - JsonArray(&'a [JsonInput]), + JsonArray(&'a [JsonValue]), JsonObject(&'a JsonObject), JsonString(&'a String), } @@ -573,7 +606,7 @@ impl<'py> Iterator for AttributesGenericIterator<'py> { } pub struct JsonObjectGenericIterator<'py> { - object_iter: SliceIter<'py, (String, JsonInput)>, + object_iter: SliceIter<'py, (String, JsonValue)>, } impl<'py> JsonObjectGenericIterator<'py> { @@ -585,7 +618,7 @@ impl<'py> JsonObjectGenericIterator<'py> { } impl<'py> Iterator for JsonObjectGenericIterator<'py> { - type Item = ValResult<'py, (&'py String, &'py JsonInput)>; + type Item = ValResult<'py, (&'py String, &'py JsonValue)>; fn next(&mut self) -> Option { self.object_iter.next().map(|(key, value)| Ok((key, value))) @@ -653,7 +686,7 @@ pub struct GenericJsonIterator { } impl GenericJsonIterator { - pub fn next(&mut self, _py: Python) -> PyResult> { + pub fn next(&mut self, _py: Python) -> PyResult> { if self.index < self.array.len() { // panic here is impossible due to bounds check above; compiler should be // able to optimize it away even @@ -667,7 +700,7 @@ impl GenericJsonIterator { } pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> { - InputValue::JsonInput(JsonInput::Array(self.array.clone())) + InputValue::JsonInput(JsonValue::Array(self.array.clone())) } pub fn index(&self) -> usize { @@ -689,12 +722,12 @@ impl<'a> PyArgs<'a> { #[cfg_attr(debug_assertions, derive(Debug))] pub struct JsonArgs<'a> { - pub args: Option<&'a [JsonInput]>, + pub args: Option<&'a [JsonValue]>, pub kwargs: Option<&'a JsonObject>, } impl<'a> JsonArgs<'a> { - pub fn new(args: Option<&'a [JsonInput]>, kwargs: Option<&'a JsonObject>) -> Self { + pub fn new(args: Option<&'a [JsonValue]>, kwargs: Option<&'a JsonObject>) -> Self { Self { args, kwargs } } } diff --git a/src/input/shared.rs b/src/input/shared.rs index 1a8e2b61c..718210098 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -1,12 +1,26 @@ +use pyo3::sync::GILOnceCell; +use pyo3::{intern, Py, PyAny, Python, ToPyObject}; + +use jiter::JsonValueError; use num_bigint::BigInt; -use pyo3::{intern, PyAny, Python}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; -use super::parse_json::{JsonArray, JsonInput}; use super::{EitherFloat, EitherInt, Input}; +static ENUM_META_OBJECT: GILOnceCell> = GILOnceCell::new(); + +pub fn get_enum_meta_object(py: Python) -> Py { + ENUM_META_OBJECT + .get_or_init(py, || { + py.import(intern!(py, "enum")) + .and_then(|enum_module| enum_module.getattr(intern!(py, "EnumMeta"))) + .unwrap() + .to_object(py) + }) + .clone() +} -pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> { +pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError<'a> { ValError::new( ErrorType::JsonInvalid { error: error.to_string(), @@ -150,7 +164,3 @@ pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a Py } Ok(EitherInt::Py(numerator)) } - -pub fn string_to_vec(s: &str) -> JsonArray { - JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect()) -} diff --git a/src/lazy_index_map.rs b/src/lazy_index_map.rs deleted file mode 100644 index c5621f877..000000000 --- a/src/lazy_index_map.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::borrow::Borrow; -use std::cmp::{Eq, PartialEq}; -use std::fmt::Debug; -use std::hash::Hash; -use std::slice::Iter as SliceIter; -use std::sync::OnceLock; - -use ahash::AHashMap; -use smallvec::SmallVec; - -#[derive(Debug, Clone, Default)] -pub struct LazyIndexMap { - vec: SmallVec<[(K, V); 8]>, - map: OnceLock>, -} - -/// Like [IndexMap](https://docs.rs/indexmap/latest/indexmap/) but only builds the lookup map when it's needed. -impl LazyIndexMap -where - K: Clone + Debug + Eq + Hash, - V: Debug, -{ - pub fn new() -> Self { - Self { - vec: SmallVec::new(), - map: OnceLock::new(), - } - } - - pub fn insert(&mut self, key: K, value: V) { - if let Some(map) = self.map.get_mut() { - map.insert(key.clone(), self.vec.len()); - } - self.vec.push((key, value)); - } - - pub fn len(&self) -> usize { - self.vec.len() - } - - pub fn get(&self, key: &Q) -> Option<&V> - where - K: Borrow + PartialEq, - Q: Hash + Eq, - { - let map = self.map.get_or_init(|| { - self.vec - .iter() - .enumerate() - .map(|(index, (key, _))| (key.clone(), index)) - .collect() - }); - map.get(key).map(|&i| &self.vec[i].1) - } - - pub fn keys(&self) -> impl Iterator { - self.vec.iter().map(|(k, _)| k) - } - - pub fn iter(&self) -> SliceIter<'_, (K, V)> { - self.vec.iter() - } -} diff --git a/src/lib.rs b/src/lib.rs index b241cdb8a..f969c0657 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,8 @@ extern crate core; use std::sync::OnceLock; +use pyo3::exceptions::PyTypeError; +use pyo3::types::{PyByteArray, PyBytes, PyString}; use pyo3::{prelude::*, sync::GILOnceCell}; // parse this first to get access to the contained macro @@ -15,7 +17,6 @@ mod build_tools; mod definitions; mod errors; mod input; -mod lazy_index_map; mod lookup_key; mod recursion_guard; mod serializers; @@ -36,6 +37,19 @@ pub use serializers::{ }; pub use validators::{validate_core_schema, PySome, SchemaValidator}; +#[pyfunction(signature = (data, *, allow_inf_nan=true))] +pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool) -> PyResult { + if let Ok(py_bytes) = data.downcast::() { + jiter::python_parse(py, py_bytes.as_bytes(), allow_inf_nan) + } else if let Ok(py_str) = data.downcast::() { + jiter::python_parse(py, py_str.to_str()?.as_bytes(), allow_inf_nan) + } else if let Ok(py_byte_array) = data.downcast::() { + jiter::python_parse(py, &py_byte_array.to_vec(), allow_inf_nan) + } else { + Err(PyTypeError::new_err("Expected bytes, bytearray or str")) + } +} + pub fn get_pydantic_core_version() -> &'static str { static PYDANTIC_CORE_VERSION: OnceLock = OnceLock::new(); @@ -95,6 +109,7 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(to_json, m)?)?; + m.add_function(wrap_pyfunction!(from_json, m)?)?; m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?; m.add_function(wrap_pyfunction!(list_all_errors, m)?)?; m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?; diff --git a/src/lookup_key.rs b/src/lookup_key.rs index 36190c069..f833c00af 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -5,9 +5,11 @@ use pyo3::exceptions::{PyAttributeError, PyTypeError}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyMapping, PyString}; +use jiter::{JsonObject, JsonValue}; + use crate::build_tools::py_schema_err; use crate::errors::{py_err_string, ErrorType, ValError, ValLineError, ValResult}; -use crate::input::{Input, JsonInput, JsonObject, StringMapping}; +use crate::input::{Input, StringMapping}; use crate::tools::{extract_i64, py_err}; /// Used for getting items from python dicts, python objects, or JSON objects, in different ways @@ -111,7 +113,7 @@ impl LookupKey { dict: &'data PyDict, ) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> { match self { - Self::Simple { py_key, path, .. } => match dict.get_item(py_key) { + Self::Simple { py_key, path, .. } => match dict.get_item(py_key)? { Some(value) => Ok(Some((path, value))), None => Ok(None), }, @@ -121,9 +123,9 @@ impl LookupKey { py_key2, path2, .. - } => match dict.get_item(py_key1) { + } => match dict.get_item(py_key1)? { Some(value) => Ok(Some((path1, value))), - None => match dict.get_item(py_key2) { + None => match dict.get_item(py_key2)? { Some(value) => Ok(Some((path2, value))), None => Ok(None), }, @@ -264,7 +266,7 @@ impl LookupKey { pub fn json_get<'data, 's>( &'s self, dict: &'data JsonObject, - ) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonInput)>> { + ) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonValue)>> { match self { Self::Simple { key, path, .. } => match dict.get(key) { Some(value) => Ok(Some((path, value))), @@ -289,13 +291,13 @@ impl LookupKey { // first step is different from the rest as we already know dict is JsonObject // because of above checks, we know that path should have at least one element, hence unwrap - let v: &JsonInput = match path_iter.next().unwrap().json_obj_get(dict) { + let v: &JsonValue = match path_iter.next().unwrap().json_obj_get(dict) { Some(v) => v, None => continue, }; // similar to above - // iterate over the path and plug each value into the JsonInput from the last step, starting with v + // iterate over the path and plug each value into the JsonValue from the last step, starting with v // from the first step, this could just be a loop but should be somewhat faster with a functional design if let Some(v) = path_iter.try_fold(v, |d, loc| loc.json_get(d)) { // Successfully found an item, return it @@ -481,10 +483,10 @@ impl PathItem { } } - pub fn json_get<'a>(&self, any_json: &'a JsonInput) -> Option<&'a JsonInput> { + pub fn json_get<'a>(&self, any_json: &'a JsonValue) -> Option<&'a JsonValue> { match any_json { - JsonInput::Object(v_obj) => self.json_obj_get(v_obj), - JsonInput::Array(v_array) => match self { + JsonValue::Object(v_obj) => self.json_obj_get(v_obj), + JsonValue::Array(v_array) => match self { Self::Pos(index) => v_array.get(*index), Self::Neg(index) => { if let Some(index) = v_array.len().checked_sub(*index) { @@ -499,7 +501,7 @@ impl PathItem { } } - pub fn json_obj_get<'a>(&self, json_obj: &'a JsonObject) -> Option<&'a JsonInput> { + pub fn json_obj_get<'a>(&self, json_obj: &'a JsonObject) -> Option<&'a JsonValue> { match self { Self::S(key, _) => json_obj.get(key), _ => None, diff --git a/src/py_gc.rs b/src/py_gc.rs index 02df02e13..8af285afb 100644 --- a/src/py_gc.rs +++ b/src/py_gc.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use ahash::AHashMap; use enum_dispatch::enum_dispatch; use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit}; @@ -35,6 +37,12 @@ impl PyGcTraverse for AHashMap { } } +impl PyGcTraverse for Arc { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + T::py_gc_traverse(self, visit) + } +} + impl PyGcTraverse for Box { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { T::py_gc_traverse(self, visit) diff --git a/src/serializers/config.rs b/src/serializers/config.rs index 4f3611129..e83497f64 100644 --- a/src/serializers/config.rs +++ b/src/serializers/config.rs @@ -129,6 +129,7 @@ pub(crate) enum BytesMode { #[default] Utf8, Base64, + Hex, } impl FromStr for BytesMode { @@ -138,7 +139,11 @@ impl FromStr for BytesMode { match s { "utf8" => Ok(Self::Utf8), "base64" => Ok(Self::Base64), - s => py_schema_err!("Invalid bytes serialization mode: `{}`, expected `utf8` or `base64`", s), + "hex" => Ok(Self::Hex), + s => py_schema_err!( + "Invalid bytes serialization mode: `{}`, expected `utf8`, `base64` or `hex`", + s + ), } } } @@ -158,6 +163,9 @@ impl BytesMode { .map_err(|err| utf8_py_error(py, err, bytes)) .map(Cow::Borrowed), Self::Base64 => Ok(Cow::Owned(base64::engine::general_purpose::URL_SAFE.encode(bytes))), + Self::Hex => Ok(Cow::Owned( + bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}")), + )), } } @@ -168,6 +176,9 @@ impl BytesMode { Err(e) => Err(Error::custom(e.to_string())), }, Self::Base64 => serializer.serialize_str(&base64::engine::general_purpose::URL_SAFE.encode(bytes)), + Self::Hex => { + serializer.serialize_str(&bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}"))) + } } } } @@ -178,3 +189,32 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr { Err(err) => err, } } + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub(crate) enum InfNanMode { + #[default] + Null, + Constants, +} + +impl FromStr for InfNanMode { + type Err = PyErr; + + fn from_str(s: &str) -> Result { + match s { + "null" => Ok(Self::Null), + "constants" => Ok(Self::Constants), + s => py_schema_err!( + "Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`", + s + ), + } + } +} + +impl FromPyObject<'_> for InfNanMode { + fn extract(ob: &'_ PyAny) -> PyResult { + let s = ob.extract::<&str>()?; + Self::from_str(s) + } +} diff --git a/src/serializers/errors.rs b/src/serializers/errors.rs index ac4ea784f..71a0a024e 100644 --- a/src/serializers/errors.rs +++ b/src/serializers/errors.rs @@ -14,8 +14,33 @@ pub(super) fn py_err_se_err(py_error: E) -> T { T::custom(py_error.to_string()) } +#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")] +#[derive(Debug, Clone)] +pub struct PythonSerializerError { + pub message: String, +} + +impl fmt::Display for PythonSerializerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for PythonSerializerError {} + +impl serde::ser::Error for PythonSerializerError { + fn custom(msg: T) -> Self + where + T: fmt::Display, + { + PythonSerializerError { + message: format!("{msg}"), + } + } +} + /// convert a serde serialization error into a `PyErr` -pub(super) fn se_err_py_err(error: serde_json::Error) -> PyErr { +pub(super) fn se_err_py_err(error: PythonSerializerError) -> PyErr { let s = error.to_string(); if let Some(msg) = s.strip_prefix(UNEXPECTED_TYPE_SER_MARKER) { if msg.is_empty() { diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 9972a82c4..7a9b84704 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -2,16 +2,14 @@ use std::cell::RefCell; use std::fmt; use pyo3::exceptions::PyValueError; +use pyo3::intern; use pyo3::prelude::*; -use pyo3::{intern, AsPyPointer}; use serde::ser::Error; use super::config::SerializationConfig; use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER}; use super::ob_type::ObTypeLookup; -use super::shared::CombinedSerializer; -use crate::definitions::Definitions; use crate::recursion_guard::RecursionGuard; /// this is ugly, would be much better if extra could be stored in `SerializationState` @@ -48,7 +46,6 @@ impl SerializationState { Extra::new( py, mode, - &[], by_alias, &self.warnings, false, @@ -72,7 +69,6 @@ impl SerializationState { #[cfg_attr(debug_assertions, derive(Debug))] pub(crate) struct Extra<'a> { pub mode: &'a SerMode, - pub definitions: &'a Definitions, pub ob_type_lookup: &'a ObTypeLookup, pub warnings: &'a CollectWarnings, pub by_alias: bool, @@ -98,7 +94,6 @@ impl<'a> Extra<'a> { pub fn new( py: Python<'a>, mode: &'a SerMode, - definitions: &'a Definitions, by_alias: bool, warnings: &'a CollectWarnings, exclude_unset: bool, @@ -112,7 +107,6 @@ impl<'a> Extra<'a> { ) -> Self { Self { mode, - definitions, ob_type_lookup: ObTypeLookup::cached(py), warnings, by_alias, @@ -156,7 +150,6 @@ impl SerCheck { #[cfg_attr(debug_assertions, derive(Debug))] pub(crate) struct ExtraOwned { mode: SerMode, - definitions: Vec, warnings: CollectWarnings, by_alias: bool, exclude_unset: bool, @@ -176,7 +169,6 @@ impl ExtraOwned { pub fn new(extra: &Extra) -> Self { Self { mode: extra.mode.clone(), - definitions: extra.definitions.to_vec(), warnings: extra.warnings.clone(), by_alias: extra.by_alias, exclude_unset: extra.exclude_unset, @@ -196,7 +188,6 @@ impl ExtraOwned { pub fn to_extra<'py>(&'py self, py: Python<'py>) -> Extra<'py> { Extra { mode: &self.mode, - definitions: &self.definitions, ob_type_lookup: ObTypeLookup::cached(py), warnings: &self.warnings, by_alias: self.by_alias, diff --git a/src/serializers/filter.rs b/src/serializers/filter.rs index 89f923552..0efec56e8 100644 --- a/src/serializers/filter.rs +++ b/src/serializers/filter.rs @@ -60,8 +60,8 @@ impl SchemaFilter { let py = schema.py(); match schema.get_as::<&PyDict>(intern!(py, "serialization"))? { Some(ser) => { - let include = Self::build_set_ints(ser.get_item(intern!(py, "include")))?; - let exclude = Self::build_set_ints(ser.get_item(intern!(py, "exclude")))?; + let include = Self::build_set_ints(ser.get_item(intern!(py, "include"))?)?; + let exclude = Self::build_set_ints(ser.get_item(intern!(py, "exclude"))?)?; Ok(Self { include, exclude }) } None => Ok(SchemaFilter::default()), @@ -325,8 +325,8 @@ fn is_ellipsis_like(v: &PyAny) -> bool { /// lookup the dict, for the key and "__all__" key, and merge them following the same rules as pydantic V1 fn merge_all_value(dict: &PyDict, py_key: impl ToPyObject + Copy) -> PyResult> { - let op_item_value = dict.get_item(py_key); - let op_all_value = dict.get_item(intern!(dict.py(), "__all__")); + let op_item_value = dict.get_item(py_key)?; + let op_all_value = dict.get_item(intern!(dict.py(), "__all__"))?; match (op_item_value, op_all_value) { (Some(item_value), Some(all_value)) => { @@ -365,7 +365,7 @@ fn merge_dicts<'py>(item_dict: &'py PyDict, all_value: &'py PyAny) -> PyResult<& let item_dict = item_dict.copy()?; if let Ok(all_dict) = all_value.downcast::() { for (all_key, all_value) in all_dict { - if let Some(item_value) = item_dict.get_item(all_key) { + if let Some(item_value) = item_dict.get_item(all_key)? { if is_ellipsis_like(item_value) { continue; } diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 6dbc076fe..e9208a510 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -5,7 +5,7 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; use pyo3::{PyTraverseError, PyVisit}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::py_gc::PyGcTraverse; use config::SerializationConfig; @@ -23,16 +23,21 @@ mod fields; mod filter; mod infer; mod ob_type; +pub mod ser; mod shared; mod type_serializers; -#[pyclass(module = "pydantic_core._pydantic_core")] +#[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaSerializer { serializer: CombinedSerializer, - definitions: Vec, + definitions: Definitions, expected_json_size: AtomicUsize, config: SerializationConfig, + // References to the Python schema and config objects are saved to enable + // reconstructing the object for pickle support (see `__reduce__`). + py_schema: Py, + py_config: Option>, } impl SchemaSerializer { @@ -54,7 +59,6 @@ impl SchemaSerializer { Extra::new( py, mode, - &self.definitions, by_alias, warnings, exclude_unset, @@ -72,15 +76,19 @@ impl SchemaSerializer { #[pymethods] impl SchemaSerializer { #[new] - pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult { + pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); - let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; Ok(Self { serializer, definitions: definitions_builder.finish()?, expected_json_size: AtomicUsize::new(1024), config: SerializationConfig::from_config(config)?, + py_schema: schema.into_py(py), + py_config: match config { + Some(c) if !c.is_empty() => Some(c.into_py(py)), + _ => None, + }, }) } @@ -175,6 +183,14 @@ impl SchemaSerializer { Ok(py_bytes.into()) } + pub fn __reduce__(slf: &PyCell) -> PyResult<(PyObject, (PyObject, PyObject))> { + // Enables support for `pickle` serialization. + let py = slf.py(); + let cls = slf.get_type().into(); + let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py)); + Ok((cls, init_args)) + } + pub fn __repr__(&self) -> String { format!( "SchemaSerializer(serializer={:#?}, definitions={:#?})", @@ -183,10 +199,12 @@ impl SchemaSerializer { } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { - self.serializer.py_gc_traverse(&visit)?; - for slot in &self.definitions { - slot.py_gc_traverse(&visit)?; + visit.call(&self.py_schema)?; + if let Some(ref py_config) = self.py_config { + visit.call(py_config)?; } + self.serializer.py_gc_traverse(&visit)?; + self.definitions.py_gc_traverse(&visit)?; Ok(()) } } diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index fc491f618..ff43a1065 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -259,8 +259,9 @@ impl ObTypeLookup { fn is_enum(&self, op_value: Option<&PyAny>, py_type: &PyType) -> bool { // only test on the type itself, not base types if op_value.is_some() { + let enum_meta_type = self.enum_object.as_ref(py_type.py()).get_type(); let meta_type = py_type.get_type(); - meta_type.is(&self.enum_object) + meta_type.is(enum_meta_type) } else { false } @@ -332,6 +333,7 @@ fn is_dataclass(op_value: Option<&PyAny>) -> bool { value .hasattr(intern!(value.py(), "__dataclass_fields__")) .unwrap_or(false) + && !value.is_instance_of::() } else { false } @@ -342,6 +344,7 @@ fn is_pydantic_serializable(op_value: Option<&PyAny>) -> bool { value .hasattr(intern!(value.py(), "__pydantic_serializer__")) .unwrap_or(false) + && !value.is_instance_of::() } else { false } diff --git a/src/serializers/ser.rs b/src/serializers/ser.rs new file mode 100644 index 000000000..170cd1849 --- /dev/null +++ b/src/serializers/ser.rs @@ -0,0 +1,1299 @@ +use std::{io, num::FpCategory}; + +use serde::{ser::Impossible, serde_if_integer128, Serialize, Serializer}; +use serde_json::ser::{CompactFormatter, Formatter, PrettyFormatter, State}; + +use super::errors::PythonSerializerError; + +macro_rules! tri { + ($e:expr $(,)?) => { + match $e { + core::result::Result::Ok(val) => val, + core::result::Result::Err(err) => return core::result::Result::Err(err), + } + }; +} + +type Result = std::result::Result; +const TOKEN: &str = "$serde_json::private::Number"; +pub struct PythonSerializer { + writer: W, + formatter: F, +} + +impl PythonSerializer +where + W: io::Write, +{ + /// Creates a new JSON serializer. + #[inline] + pub fn new(writer: W) -> Self { + PythonSerializer::with_formatter(writer, CompactFormatter) + } +} + +impl<'a, W> PythonSerializer> +where + W: io::Write, +{ + /// Creates a new JSON pretty print serializer. + #[inline] + pub fn pretty(writer: W) -> Self { + PythonSerializer::with_formatter(writer, PrettyFormatter::new()) + } +} + +impl PythonSerializer +where + W: io::Write, + F: Formatter, +{ + /// Creates a new JSON visitor whose output will be written to the writer + /// specified. + #[inline] + pub fn with_formatter(writer: W, formatter: F) -> Self { + PythonSerializer { writer, formatter } + } + + /// Unwrap the `Writer` from the `Serializer`. + #[inline] + pub fn into_inner(self) -> W { + self.writer + } +} + +impl<'a, W, F> Serializer for &'a mut PythonSerializer +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + type SerializeSeq = Compound<'a, W, F>; + type SerializeTuple = Compound<'a, W, F>; + type SerializeTupleStruct = Compound<'a, W, F>; + type SerializeTupleVariant = Compound<'a, W, F>; + type SerializeMap = Compound<'a, W, F>; + type SerializeStruct = Compound<'a, W, F>; + type SerializeStructVariant = Compound<'a, W, F>; + + #[inline] + fn serialize_bool(self, value: bool) -> Result<()> { + self.formatter + .write_bool(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + #[inline] + fn serialize_i8(self, value: i8) -> Result<()> { + self.formatter + .write_i8(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_i16(self, value: i16) -> Result { + self.formatter + .write_i16(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_i32(self, value: i32) -> Result { + self.formatter + .write_i32(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_i64(self, value: i64) -> Result { + self.formatter + .write_i64(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u8(self, value: u8) -> Result { + self.formatter + .write_u8(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u16(self, value: u16) -> Result { + self.formatter + .write_u16(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u32(self, value: u32) -> Result { + self.formatter + .write_u32(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u64(self, value: u64) -> Result { + self.formatter + .write_u64(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u128(self, value: u128) -> Result<()> { + self.formatter + .write_u128(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + #[inline] + fn serialize_f32(self, value: f32) -> Result<()> { + match value.classify() { + FpCategory::Nan => self + .formatter + .write_number_str(&mut self.writer, "NaN") + .map_err(|e| PythonSerializerError { message: e.to_string() }), + FpCategory::Infinite => { + let infinity = if value.is_sign_negative() { + "-Infinity" + } else { + "Infinity" + }; + self.formatter + .write_number_str(&mut self.writer, infinity) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + _ => self + .formatter + .write_f32(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }), + } + } + + fn serialize_f64(self, value: f64) -> Result { + match value.classify() { + FpCategory::Nan => self + .formatter + .write_number_str(&mut self.writer, "NaN") + .map_err(|e| PythonSerializerError { message: e.to_string() }), + FpCategory::Infinite => { + let infinity = if value.is_sign_negative() { + "-Infinity" + } else { + "Infinity" + }; + self.formatter + .write_number_str(&mut self.writer, infinity) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + _ => self + .formatter + .write_f64(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }), + } + } + + fn serialize_char(self, value: char) -> Result { + // A char encoded as UTF-8 takes 4 bytes at most. + let mut buf = [0; 4]; + self.serialize_str(value.encode_utf8(&mut buf)) + } + + fn serialize_str(self, value: &str) -> Result { + format_escaped_str(&mut self.writer, &mut self.formatter, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_bytes(self, value: &[u8]) -> Result<()> { + self.formatter + .write_byte_array(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_none(self) -> Result { + self.formatter + .write_null(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_unit(self) -> Result { + self.formatter + .write_null(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result + where + T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: Serialize, + { + tri!(self + .formatter + .begin_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_key(&mut self.writer, true) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self.serialize_str(variant)); + tri!(self + .formatter + .end_object_key(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_value(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(value.serialize(&mut *self)); + tri!(self + .formatter + .end_object_value(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + self.formatter + .end_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_seq(self, len: Option) -> Result { + tri!(self + .formatter + .begin_array(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + if len == Some(0) { + tri!(self + .formatter + .end_array(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(Compound::Map { + ser: self, + state: State::Empty, + }) + } else { + Ok(Compound::Map { + ser: self, + state: State::First, + }) + } + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_struct(self, _name: &'static str, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + tri!(self + .formatter + .begin_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_key(&mut self.writer, true) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self.serialize_str(variant)); + tri!(self + .formatter + .end_object_key(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_value(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + self.serialize_seq(Some(len)) + } + + fn serialize_map(self, len: Option) -> Result { + tri!(self + .formatter + .begin_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + if len == Some(0) { + tri!(self + .formatter + .end_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(Compound::Map { + ser: self, + state: State::Empty, + }) + } else { + Ok(Compound::Map { + ser: self, + state: State::First, + }) + } + } + + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + match name { + TOKEN => Ok(Compound::Number { ser: self }), + _ => self.serialize_map(Some(len)), + } + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + tri!(self + .formatter + .begin_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_key(&mut self.writer, true) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self.serialize_str(variant)); + tri!(self + .formatter + .end_object_key(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_value(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + self.serialize_map(Some(len)) + } +} + +impl<'a, W, F> serde::ser::SerializeSeq for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match self { + Compound::Map { ser, state } => { + tri!(ser + .formatter + .begin_array_value(&mut ser.writer, *state == State::First) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + *state = State::Rest; + tri!(value.serialize(&mut **ser)); + tri!(ser + .formatter + .end_array_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } + + fn end(self) -> Result<()> { + match self { + Compound::Map { ser, state } => { + match state { + State::Empty => {} + _ => tri!(ser + .formatter + .end_array(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })), + } + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } +} + +impl<'a, W, F> serde::ser::SerializeTuple for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result<()> { + serde::ser::SerializeSeq::end(self) + } +} + +impl<'a, W, F> serde::ser::SerializeTupleStruct for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result<()> { + serde::ser::SerializeSeq::end(self) + } +} + +impl<'a, W, F> serde::ser::SerializeTupleVariant for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result<()> { + match self { + Compound::Map { ser, state } => { + match state { + State::Empty => {} + _ => tri!(ser + .formatter + .end_array(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })), + } + tri!(ser + .formatter + .end_object_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(ser + .formatter + .end_object(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } +} + +impl<'a, W, F> serde::ser::SerializeMap for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match self { + Compound::Map { ser, state } => { + tri!(ser + .formatter + .begin_object_key(&mut ser.writer, *state == State::First) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + *state = State::Rest; + + tri!(key.serialize(MapKeySerializer { ser: *ser })); + + tri!(ser + .formatter + .end_object_key(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } + + #[inline] + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match self { + Compound::Map { ser, .. } => { + tri!(ser + .formatter + .begin_object_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(value.serialize(&mut **ser)); + tri!(ser + .formatter + .end_object_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } + + #[inline] + fn end(self) -> Result<()> { + match self { + Compound::Map { ser, state } => { + match state { + State::Empty => {} + _ => tri!(ser + .formatter + .end_object(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })), + } + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } +} + +impl<'a, W, F> serde::ser::SerializeStruct for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match self { + Compound::Map { .. } => serde::ser::SerializeMap::serialize_entry(self, key, value), + Compound::Number { ser, .. } => { + if key == TOKEN { + tri!(value.serialize(NumberStrEmitter(ser))); + Ok(()) + } else { + Err(invalid_number()) + } + } + } + } + + #[inline] + fn end(self) -> Result<()> { + match self { + Compound::Map { .. } => serde::ser::SerializeMap::end(self), + Compound::Number { .. } => Ok(()), + } + } +} + +impl<'a, W, F> serde::ser::SerializeStructVariant for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match *self { + Compound::Map { .. } => serde::ser::SerializeStruct::serialize_field(self, key, value), + Compound::Number { .. } => unreachable!(), + } + } + + #[inline] + fn end(self) -> Result<()> { + match self { + Compound::Map { ser, state } => { + match state { + State::Empty => {} + _ => tri!(ser + .formatter + .end_object(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })), + } + tri!(ser + .formatter + .end_object_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(ser + .formatter + .end_object(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } +} + +fn format_escaped_str(writer: &mut W, formatter: &mut F, value: &str) -> io::Result<()> +where + W: ?Sized + io::Write, + F: ?Sized + Formatter, +{ + tri!(formatter.begin_string(writer)); + tri!(format_escaped_str_contents(writer, formatter, value)); + formatter.end_string(writer) +} + +fn format_escaped_str_contents(writer: &mut W, formatter: &mut F, value: &str) -> io::Result<()> +where + W: ?Sized + io::Write, + F: ?Sized + Formatter, +{ + let bytes = value.as_bytes(); + + let mut start = 0; + + for (i, &byte) in bytes.iter().enumerate() { + let escape = ESCAPE[byte as usize]; + if escape == 0 { + continue; + } + + if start < i { + tri!(formatter.write_string_fragment(writer, &value[start..i])); + } + + let char_escape = CharEscape::from_escape_table(escape, byte); + tri!(formatter.write_char_escape(writer, char_escape)); + + start = i + 1; + } + + if start == bytes.len() { + return Ok(()); + } + + formatter.write_string_fragment(writer, &value[start..]) +} + +const BB: u8 = b'b'; // \x08 +const TT: u8 = b't'; // \x09 +const NN: u8 = b'n'; // \x0A +const FF: u8 = b'f'; // \x0C +const RR: u8 = b'r'; // \x0D +const QU: u8 = b'"'; // \x22 +const BS: u8 = b'\\'; // \x5C +const UU: u8 = b'u'; // \x00...\x1F except the ones above +const __: u8 = 0; + +// Lookup table of escape sequences. A value of b'x' at index i means that byte +// i is escaped as "\x" in JSON. A value of 0 means that byte i is not escaped. +static ESCAPE: [u8; 256] = [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + UU, UU, UU, UU, UU, UU, UU, UU, BB, TT, NN, UU, FF, RR, UU, UU, // 0 + UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, // 1 + __, __, QU, __, __, __, __, __, __, __, __, __, __, __, __, __, // 2 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 3 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 + __, __, __, __, __, __, __, __, __, __, __, __, BS, __, __, __, // 5 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F +]; + +pub enum Compound<'a, W: 'a, F: 'a> { + Map { + ser: &'a mut PythonSerializer, + state: State, + }, + Number { + ser: &'a mut PythonSerializer, + }, +} + +/// Represents a character escape code in a type-safe manner. +pub enum CharEscape {} + +impl CharEscape { + #[inline] + fn from_escape_table(escape: u8, byte: u8) -> serde_json::ser::CharEscape { + match escape { + self::BB => serde_json::ser::CharEscape::Backspace, + self::TT => serde_json::ser::CharEscape::Tab, + self::NN => serde_json::ser::CharEscape::LineFeed, + self::FF => serde_json::ser::CharEscape::FormFeed, + self::RR => serde_json::ser::CharEscape::CarriageReturn, + self::QU => serde_json::ser::CharEscape::Quote, + self::BS => serde_json::ser::CharEscape::ReverseSolidus, + self::UU => serde_json::ser::CharEscape::AsciiControl(byte), + _ => unreachable!(), + } + } +} + +struct MapKeySerializer<'a, W: 'a, F: 'a> { + ser: &'a mut PythonSerializer, +} + +fn key_must_be_a_string() -> PythonSerializerError { + PythonSerializerError { + message: "Key must be a string".to_string(), + } +} +fn invalid_number() -> PythonSerializerError { + PythonSerializerError { + message: "Invalid Number".to_string(), + } +} + +impl<'a, W, F> serde::ser::Serializer for MapKeySerializer<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_str(self, value: &str) -> Result<()> { + self.ser.serialize_str(value) + } + + #[inline] + fn serialize_unit_variant(self, _name: &'static str, _variant_index: u32, variant: &'static str) -> Result<()> { + self.ser.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + type SerializeSeq = Impossible<(), PythonSerializerError>; + type SerializeTuple = Impossible<(), PythonSerializerError>; + type SerializeTupleStruct = Impossible<(), PythonSerializerError>; + type SerializeTupleVariant = Impossible<(), PythonSerializerError>; + type SerializeMap = Impossible<(), PythonSerializerError>; + type SerializeStruct = Impossible<(), PythonSerializerError>; + type SerializeStructVariant = Impossible<(), PythonSerializerError>; + + fn serialize_bool(self, _value: bool) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_i8(self, value: i8) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_i8(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_i16(self, value: i16) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_i16(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_i32(self, value: i32) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_i32(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_i64(self, value: i64) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_i64(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + serde_if_integer128! { + fn serialize_i128(self, value: i128) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_number_str(&mut self.ser.writer, &value.to_string()) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + } + + fn serialize_u8(self, value: u8) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_u8(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_u16(self, value: u16) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_u16(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_u32(self, value: u32) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_u32(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_u64(self, value: u64) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_u64(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + serde_if_integer128! { + fn serialize_u128(self, value: u128) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_number_str(&mut self.ser.writer, &value.to_string()) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + } + + fn serialize_f32(self, _value: f32) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_f64(self, _value: f64) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_char(self, value: char) -> Result<()> { + self.ser.serialize_str(&value.to_string()) + } + + fn serialize_bytes(self, _value: &[u8]) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_unit(self) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(key_must_be_a_string()) + } + + fn serialize_none(self) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_some(self, _value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(key_must_be_a_string()) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(key_must_be_a_string()) + } + + fn collect_str(self, value: &T) -> Result<()> + where + T: ?Sized + std::fmt::Display, + { + self.ser.collect_str(value) + } +} + +struct NumberStrEmitter<'a, W: 'a + io::Write, F: 'a + Formatter>(&'a mut PythonSerializer); + +impl<'a, W: io::Write, F: Formatter> serde::ser::Serializer for NumberStrEmitter<'a, W, F> { + type Ok = (); + type Error = PythonSerializerError; + + type SerializeSeq = Impossible<(), PythonSerializerError>; + type SerializeTuple = Impossible<(), PythonSerializerError>; + type SerializeTupleStruct = Impossible<(), PythonSerializerError>; + type SerializeTupleVariant = Impossible<(), PythonSerializerError>; + type SerializeMap = Impossible<(), PythonSerializerError>; + type SerializeStruct = Impossible<(), PythonSerializerError>; + type SerializeStructVariant = Impossible<(), PythonSerializerError>; + + fn serialize_bool(self, _v: bool) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_i8(self, _v: i8) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_i16(self, _v: i16) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_i32(self, _v: i32) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_i64(self, _v: i64) -> Result<()> { + Err(invalid_number()) + } + + serde_if_integer128! { + fn serialize_i128(self, _v: i128) -> Result<()> { + Err(invalid_number()) + } + } + + fn serialize_u8(self, _v: u8) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_u16(self, _v: u16) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_u32(self, _v: u32) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_u64(self, _v: u64) -> Result<()> { + Err(invalid_number()) + } + + serde_if_integer128! { + fn serialize_u128(self, _v: u128) -> Result<()> { + Err(invalid_number()) + } + } + + fn serialize_f32(self, _v: f32) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_f64(self, _v: f64) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_char(self, _v: char) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_str(self, value: &str) -> Result<()> { + let NumberStrEmitter(serializer) = self; + serializer + .formatter + .write_number_str(&mut serializer.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_bytes(self, _value: &[u8]) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_none(self) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_some(self, _value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(invalid_number()) + } + + fn serialize_unit(self) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_unit_variant(self, _name: &'static str, _variant_index: u32, _variant: &'static str) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(invalid_number()) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(invalid_number()) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(invalid_number()) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(invalid_number()) + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + Err(invalid_number()) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(invalid_number()) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(invalid_number()) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(invalid_number()) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(invalid_number()) + } +} diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index b9b0c1fe1..cfccc748a 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -13,8 +13,9 @@ use serde_json::ser::PrettyFormatter; use crate::build_tools::py_schema_err; use crate::build_tools::py_schema_error_type; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::py_gc::PyGcTraverse; +use crate::serializers::ser::PythonSerializer; use crate::tools::{py_err, SchemaDict}; use super::errors::se_err_py_err; @@ -112,7 +113,7 @@ combined_serializer! { Nullable: super::type_serializers::nullable::NullableSerializer; Int: super::type_serializers::simple::IntSerializer; Bool: super::type_serializers::simple::BoolSerializer; - Float: super::type_serializers::simple::FloatSerializer; + Float: super::type_serializers::float::FloatSerializer; Decimal: super::type_serializers::decimal::DecimalSerializer; Str: super::type_serializers::string::StrSerializer; Bytes: super::type_serializers::bytes::BytesSerializer; @@ -293,7 +294,7 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug { fn get_name(&self) -> &str; /// Used by union serializers to decide if it's worth trying again while allowing subclasses - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { false } @@ -352,12 +353,12 @@ pub(crate) fn to_json_bytes( Some(indent) => { let indent = vec![b' '; indent]; let formatter = PrettyFormatter::with_indent(&indent); - let mut ser = serde_json::Serializer::with_formatter(writer, formatter); + let mut ser = PythonSerializer::with_formatter(writer, formatter); serializer.serialize(&mut ser).map_err(se_err_py_err)?; ser.into_inner() } None => { - let mut ser = serde_json::Serializer::new(writer); + let mut ser = PythonSerializer::new(writer); serializer.serialize(&mut ser).map_err(se_err_py_err)?; ser.into_inner() } diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index 124f962ad..787e267dd 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -6,7 +6,7 @@ use std::borrow::Cow; use ahash::AHashMap; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use super::{ @@ -179,7 +179,7 @@ impl TypeSerializer for DataclassSerializer { &self.name } - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { true } } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 4614bbc56..b7bf63365 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; -use crate::definitions::Definitions; +use crate::definitions::DefinitionRef; use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; @@ -41,7 +41,7 @@ impl BuildSerializer for DefinitionsSerializerBuilder { #[derive(Debug, Clone)] pub struct DefinitionRefSerializer { - serializer_id: usize, + definition: DefinitionRef, } impl BuildSerializer for DefinitionRefSerializer { @@ -52,9 +52,9 @@ impl BuildSerializer for DefinitionRefSerializer { _config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let schema_ref: String = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - let serializer_id = definitions.get_reference_id(&schema_ref); - Ok(Self { serializer_id }.into()) + let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; + let definition = definitions.get_definition(schema_ref); + Ok(Self { definition }.into()) } } @@ -68,15 +68,15 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - let value_id = extra.rec_guard.add(value, self.serializer_id)?; - let comb_serializer = extra.definitions.get(self.serializer_id).unwrap(); + let comb_serializer = self.definition.get().unwrap(); + let value_id = extra.rec_guard.add(value, self.definition.id())?; let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.serializer_id); + extra.rec_guard.pop(value_id, self.definition.id()); r } fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { - self._invalid_as_json_key(key, extra, Self::EXPECTED_TYPE) + self.definition.get().unwrap().json_key(key, extra) } fn serde_serialize( @@ -87,10 +87,13 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let value_id = extra.rec_guard.add(value, self.serializer_id).map_err(py_err_se_err)?; - let comb_serializer = extra.definitions.get(self.serializer_id).unwrap(); + let comb_serializer = self.definition.get().unwrap(); + let value_id = extra + .rec_guard + .add(value, self.definition.id()) + .map_err(py_err_se_err)?; let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.serializer_id); + extra.rec_guard.pop(value_id, self.definition.id()); r } @@ -98,8 +101,7 @@ impl TypeSerializer for DefinitionRefSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - let comb_serializer = definitions.get(self.serializer_id).unwrap(); - comb_serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.definition.get().unwrap().retry_with_lax_check() } } diff --git a/src/serializers/type_serializers/dict.rs b/src/serializers/type_serializers/dict.rs index bb2a18633..89851752e 100644 --- a/src/serializers/type_serializers/dict.rs +++ b/src/serializers/type_serializers/dict.rs @@ -43,8 +43,8 @@ impl BuildSerializer for DictSerializer { }; let filter = match schema.get_as::<&PyDict>(intern!(py, "serialization"))? { Some(ser) => { - let include = ser.get_item(intern!(py, "include")); - let exclude = ser.get_item(intern!(py, "exclude")); + let include = ser.get_item(intern!(py, "include"))?; + let exclude = ser.get_item(intern!(py, "exclude"))?; SchemaFilter::from_set_hash(include, exclude)? } None => SchemaFilter::default(), diff --git a/src/serializers/type_serializers/float.rs b/src/serializers/type_serializers/float.rs new file mode 100644 index 000000000..23dcacf1a --- /dev/null +++ b/src/serializers/type_serializers/float.rs @@ -0,0 +1,102 @@ +use pyo3::types::PyDict; +use pyo3::{intern, prelude::*}; + +use std::borrow::Cow; + +use serde::Serializer; + +use crate::definitions::DefinitionsBuilder; +use crate::serializers::config::InfNanMode; +use crate::tools::SchemaDict; + +use super::simple::to_str_json_key; +use super::{ + infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, + SerMode, TypeSerializer, +}; + +#[derive(Debug, Clone)] +pub struct FloatSerializer { + inf_nan_mode: InfNanMode, +} + +impl BuildSerializer for FloatSerializer { + const EXPECTED_TYPE: &'static str = "float"; + + fn build( + schema: &PyDict, + config: Option<&PyDict>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let inf_nan_mode = config + .and_then(|c| c.get_as(intern!(schema.py(), "ser_json_inf_nan")).transpose()) + .transpose()? + .unwrap_or_default(); + Ok(Self { inf_nan_mode }.into()) + } +} + +impl_py_gc_traverse!(FloatSerializer {}); + +impl TypeSerializer for FloatSerializer { + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult { + let py = value.py(); + match extra.ob_type_lookup.is_type(value, ObType::Float) { + IsType::Exact => Ok(value.into_py(py)), + IsType::Subclass => match extra.mode { + SerMode::Json => { + let rust_value = value.extract::()?; + Ok(rust_value.to_object(py)) + } + _ => infer_to_python(value, include, exclude, extra), + }, + IsType::False => { + extra.warnings.on_fallback_py(self.get_name(), value, extra)?; + infer_to_python(value, include, exclude, extra) + } + } + } + + fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { + match extra.ob_type_lookup.is_type(key, ObType::Float) { + IsType::Exact | IsType::Subclass => to_str_json_key(key), + IsType::False => { + extra.warnings.on_fallback_py(self.get_name(), key, extra)?; + infer_json_key(key, extra) + } + } + } + + fn serde_serialize( + &self, + value: &PyAny, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result { + match value.extract::() { + Ok(v) => { + if (v.is_nan() || v.is_infinite()) && self.inf_nan_mode == InfNanMode::Null { + serializer.serialize_none() + } else { + serializer.serialize_f64(v) + } + } + Err(_) => { + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) + } + } + } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } +} diff --git a/src/serializers/type_serializers/list.rs b/src/serializers/type_serializers/list.rs index 4d68ae373..a71e3452a 100644 --- a/src/serializers/type_serializers/list.rs +++ b/src/serializers/type_serializers/list.rs @@ -116,4 +116,8 @@ impl TypeSerializer for ListSerializer { fn get_name(&self) -> &str { &self.name } + + fn retry_with_lax_check(&self) -> bool { + self.item_serializer.retry_with_lax_check() + } } diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index b942b5b86..decb07aaf 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -5,6 +5,7 @@ pub mod datetime_etc; pub mod decimal; pub mod definitions; pub mod dict; +pub mod float; pub mod format; pub mod function; pub mod generator; diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index c5b252fbf..0d2d1d346 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -13,7 +13,7 @@ use super::{ }; use crate::build_tools::py_schema_err; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::serializers::errors::PydanticSerializationUnexpectedValue; use crate::tools::SchemaDict; @@ -39,7 +39,7 @@ impl BuildSerializer for ModelFieldsBuilder { let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?; let mut fields: AHashMap = AHashMap::with_capacity(fields_dict.len()); - let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) { + let extra_serializer = match (schema.get_item(intern!(py, "extras_schema"))?, &fields_mode) { (Some(v), FieldsMode::ModelExtra) => Some(CombinedSerializer::build(v.extract()?, config, definitions)?), (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, @@ -228,7 +228,7 @@ impl TypeSerializer for ModelSerializer { &self.name } - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { true } } diff --git a/src/serializers/type_serializers/nullable.rs b/src/serializers/type_serializers/nullable.rs index 837d6c5f1..23349ec81 100644 --- a/src/serializers/type_serializers/nullable.rs +++ b/src/serializers/type_serializers/nullable.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use super::{infer_json_key_known, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, TypeSerializer}; @@ -75,7 +75,7 @@ impl TypeSerializer for NullableSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() } } diff --git a/src/serializers/type_serializers/simple.rs b/src/serializers/type_serializers/simple.rs index f0d90c2bf..dafb2b786 100644 --- a/src/serializers/type_serializers/simple.rs +++ b/src/serializers/type_serializers/simple.rs @@ -180,4 +180,3 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult> { } build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key); -build_simple_serializer!(FloatSerializer, "float", f64, ObType::Float, to_str_json_key); diff --git a/src/serializers/type_serializers/typed_dict.rs b/src/serializers/type_serializers/typed_dict.rs index 5967738ae..fbef3486a 100644 --- a/src/serializers/type_serializers/typed_dict.rs +++ b/src/serializers/type_serializers/typed_dict.rs @@ -35,7 +35,7 @@ impl BuildSerializer for TypedDictBuilder { let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?; let mut fields: AHashMap = AHashMap::with_capacity(fields_dict.len()); - let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) { + let extra_serializer = match (schema.get_item(intern!(py, "extras_schema"))?, &fields_mode) { (Some(v), FieldsMode::TypedDictAllow) => { Some(CombinedSerializer::build(v.extract()?, config, definitions)?) } diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 70818959e..f05e2220e 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -4,7 +4,7 @@ use pyo3::types::{PyDict, PyList, PyTuple}; use std::borrow::Cow; use crate::build_tools::py_schema_err; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use crate::PydanticSerializationUnexpectedValue; @@ -75,9 +75,10 @@ impl TypeSerializer for UnionSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - // try the serializers in with error_on fallback=true + // try the serializers in left to right order with error_on fallback=true let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; + for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return Ok(v), @@ -87,7 +88,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -116,7 +117,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.json_key(key, &new_extra) { @@ -153,7 +154,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -174,10 +175,8 @@ impl TypeSerializer for UnionSerializer { &self.name } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.choices - .iter() - .any(|choice| choice.retry_with_lax_check(definitions)) + fn retry_with_lax_check(&self) -> bool { + self.choices.iter().any(CombinedSerializer::retry_with_lax_check) } } diff --git a/src/serializers/type_serializers/with_default.rs b/src/serializers/type_serializers/with_default.rs index 148c05052..d20c273a1 100644 --- a/src/serializers/type_serializers/with_default.rs +++ b/src/serializers/type_serializers/with_default.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use crate::validators::DefaultType; @@ -67,8 +67,8 @@ impl TypeSerializer for WithDefaultSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() } fn get_default(&self, py: Python) -> PyResult> { diff --git a/src/tools.rs b/src/tools.rs index 3c75decf1..af58131f5 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -20,7 +20,7 @@ impl<'py> SchemaDict<'py> for PyDict { where T: FromPyObject<'py>, { - match self.get_item(key) { + match self.get_item(key)? { Some(t) => Ok(Some(::extract(t)?)), None => Ok(None), } @@ -30,7 +30,7 @@ impl<'py> SchemaDict<'py> for PyDict { where T: FromPyObject<'py>, { - match self.get_item(key) { + match self.get_item(key)? { Some(t) => ::extract(t), None => py_err!(PyKeyError; "{}", key), } diff --git a/src/validators/any.rs b/src/validators/any.rs index eddde1725..2fad89091 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -4,7 +4,9 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::Input; -use super::{validation_state::ValidationState, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{ + validation_state::Exactness, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator, +}; /// This might seem useless, but it's useful in DictValidator to avoid Option a lot #[derive(Debug, Clone)] @@ -29,24 +31,14 @@ impl Validator for AnyValidator { &self, py: Python<'data>, input: &'data impl Input<'data>, - _state: &mut ValidationState, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + // in a union, Any should be preferred to doing lax coercions + state.floor_exactness(Exactness::Strict); Ok(input.to_object(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { - false - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 2c0fe4a0a..7ae65d579 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -6,8 +6,8 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::schema_or_config_same; -use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; -use crate::input::{GenericArguments, Input}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::input::{GenericArguments, Input, ValidationMatch}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; @@ -15,7 +15,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] struct Parameter { positional: bool, name: String, @@ -24,7 +24,7 @@ struct Parameter { validator: CombinedValidator, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ArgumentsValidator { parameters: Vec, positional_params_count: usize, @@ -66,7 +66,7 @@ impl BuildValidator for ArgumentsValidator { let mut kw_lookup_key = None; let mut kwarg_key = None; if mode == "keyword_only" || mode == "positional_or_keyword" { - kw_lookup_key = match arg.get_item(intern!(py, "alias")) { + kw_lookup_key = match arg.get_item(intern!(py, "alias"))? { Some(alias) => { let alt_alias = if populate_by_name { Some(name.as_str()) } else { None }; Some(LookupKey::from_py(py, alias, alt_alias)?) @@ -110,11 +110,11 @@ impl BuildValidator for ArgumentsValidator { Ok(Self { parameters, positional_params_count, - var_args_validator: match schema.get_item(intern!(py, "var_args_schema")) { + var_args_validator: match schema.get_item(intern!(py, "var_args_schema"))? { Some(v) => Some(Box::new(build_validator(v, config, definitions)?)), None => None, }, - var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema")) { + var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? { Some(v) => Some(Box::new(build_validator(v, config, definitions)?)), None => None, }, @@ -282,7 +282,7 @@ impl Validator for ArgumentsValidator { if let Some(kwargs) = $args.kwargs { if kwargs.len() > used_kwargs.len() { for (raw_key, value) in kwargs.iter() { - let either_str = match raw_key.strict_str() { + let either_str = match raw_key.validate_str(true, false).map(ValidationMatch::into_inner) { Ok(k) => k, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -332,30 +332,7 @@ impl Validator for ArgumentsValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.parameters - .iter() - .any(|p| p.validator.different_strict_behavior(definitions, ultra_strict)) - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.parameters - .iter_mut() - .try_for_each(|parameter| parameter.validator.complete(definitions))?; - if let Some(v) = &mut self.var_args_validator { - v.complete(definitions)?; - } - if let Some(v) = &mut self.var_kwargs_validator { - v.complete(definitions)?; - }; - Ok(()) - } } diff --git a/src/validators/bool.rs b/src/validators/bool.rs index d87c1c1d7..bcd48e991 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -38,23 +38,12 @@ impl Validator for BoolValidator { ) -> ValResult<'data, PyObject> { // TODO in theory this could be quicker if we used PyBool rather than going to a bool // and back again, might be worth profiling? - let strict = state.strict_or(self.strict); - Ok(input.validate_bool(strict)?.into_py(py)) - } - - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict + input + .validate_bool(state.strict_or(self.strict)) + .map(|val_match| val_match.unpack(state).into_py(py)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 2084f916e..78a8acb24 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -24,8 +24,8 @@ impl BuildValidator for BytesValidator { _definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let use_constrained = schema.get_item(intern!(py, "max_length")).is_some() - || schema.get_item(intern!(py, "min_length")).is_some(); + let use_constrained = schema.get_item(intern!(py, "max_length"))?.is_some() + || schema.get_item(intern!(py, "min_length"))?.is_some(); if use_constrained { BytesConstrainedValidator::build(schema, config) } else { @@ -46,25 +46,14 @@ impl Validator for BytesValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_bytes = input.validate_bytes(state.strict_or(self.strict))?; - Ok(either_bytes.into_py(py)) - } - - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict + input + .validate_bytes(state.strict_or(self.strict)) + .map(|m| m.unpack(state).into_py(py)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } #[derive(Debug, Clone)] @@ -83,7 +72,7 @@ impl Validator for BytesConstrainedValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_bytes = input.validate_bytes(state.strict_or(self.strict))?; + let either_bytes = input.validate_bytes(state.strict_or(self.strict))?.unpack(state); let len = either_bytes.len()?; if let Some(min_length) = self.min_length { @@ -108,25 +97,12 @@ impl Validator for BytesConstrainedValidator { )); } } - Ok(either_bytes.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { "constrained-bytes" } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } impl BytesConstrainedValidator { diff --git a/src/validators/call.rs b/src/validators/call.rs index 24c7f4111..e0649aa53 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -11,7 +11,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct CallValidator { function: PyObject, arguments_validator: Box, @@ -32,7 +32,7 @@ impl BuildValidator for CallValidator { let arguments_schema: &PyAny = schema.get_as_req(intern!(py, "arguments_schema"))?; let arguments_validator = Box::new(build_validator(arguments_schema, config, definitions)?); - let return_schema = schema.get_item(intern!(py, "return_schema")); + let return_schema = schema.get_item(intern!(py, "return_schema"))?; let return_validator = match return_schema { Some(return_schema) => Some(Box::new(build_validator(return_schema, config, definitions)?)), None => None, @@ -98,29 +98,7 @@ impl Validator for CallValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if let Some(return_validator) = &self.return_validator { - if return_validator.different_strict_behavior(definitions, ultra_strict) { - return true; - } - } - self.arguments_validator - .different_strict_behavior(definitions, ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.arguments_validator.complete(definitions)?; - match &mut self.return_validator { - Some(v) => v.complete(definitions), - None => Ok(()), - } - } } diff --git a/src/validators/callable.rs b/src/validators/callable.rs index 9b565e3eb..3075e182e 100644 --- a/src/validators/callable.rs +++ b/src/validators/callable.rs @@ -4,6 +4,7 @@ use pyo3::types::PyDict; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; +use super::validation_state::Exactness; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] @@ -28,27 +29,16 @@ impl Validator for CallableValidator { &self, py: Python<'data>, input: &'data impl Input<'data>, - _state: &mut ValidationState, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + state.floor_exactness(Exactness::Lax); match input.callable() { true => Ok(input.to_object(py)), false => Err(ValError::new(ErrorTypeDefaults::CallableType, input)), } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { - false - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/chain.rs b/src/validators/chain.rs index 001947d1f..d8da86e30 100644 --- a/src/validators/chain.rs +++ b/src/validators/chain.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ChainValidator { steps: Vec, name: String, @@ -83,21 +83,7 @@ impl Validator for ChainValidator { steps_iter.try_fold(value, |v, step| step.validate(py, v.into_ref(py), state)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.steps - .iter() - .any(|v| v.different_strict_behavior(definitions, ultra_strict)) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.steps.iter_mut().try_for_each(|v| v.complete(definitions)) - } } diff --git a/src/validators/custom_error.rs b/src/validators/custom_error.rs index 1e8258090..4ea31aa5a 100644 --- a/src/validators/custom_error.rs +++ b/src/validators/custom_error.rs @@ -57,7 +57,7 @@ impl CustomError { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct CustomErrorValidator { validator: Box, custom_error: CustomError, @@ -99,19 +99,7 @@ impl Validator for CustomErrorValidator { .map_err(|_| self.custom_error.as_val_error(input)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) - } } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 117596b9f..1646f5ea6 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -7,19 +7,20 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; -use crate::input::{BorrowInput, GenericArguments, Input}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::input::{BorrowInput, GenericArguments, Input, ValidationMatch}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; use crate::validators::function::convert_err; use super::arguments::{json_get, json_slice, py_get, py_slice}; use super::model::{create_class, force_setattr, Revalidate}; +use super::validation_state::Exactness; use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; -#[derive(Debug, Clone)] +#[derive(Debug)] struct Field { kw_only: bool, name: String, @@ -30,7 +31,7 @@ struct Field { frozen: bool, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DataclassArgsValidator { fields: Vec, positional_count: usize, @@ -56,7 +57,7 @@ impl BuildValidator for DataclassArgsValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; - let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { + let extras_validator = match (schema.get_item(intern!(py, "extras_schema"))?, &extra_behavior) { (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)), (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, @@ -73,7 +74,7 @@ impl BuildValidator for DataclassArgsValidator { let py_name: &PyString = field.get_as_req(intern!(py, "name"))?; let name: String = py_name.extract()?; - let lookup_key = match field.get_item(intern!(py, "validation_alias")) { + let lookup_key = match field.get_item(intern!(py, "validation_alias"))? { Some(alias) => { let alt_alias = if populate_by_name { Some(name.as_str()) } else { None }; LookupKey::from_py(py, alias, alt_alias)? @@ -232,19 +233,31 @@ impl Validator for DataclassArgsValidator { } // found neither, check if there is a default value, otherwise error (None, None) => { - if let Some(value) = - field - .validator - .default_value(py, Some(field.name.as_str()), state)? - { - set_item!(field, value); - } else { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name, - )); + match field.validator.default_value(py, Some(field.name.as_str()), state) { + Ok(Some(value)) => { + // Default value exists, and passed validation if required + set_item!(field, value); + }, + Ok(None) => { + // This means there was no default value + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name + )); + }, + Err(ValError::Omit) => continue, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + // Note: this will always use the field name even if there is an alias + // However, we don't mind so much because this error can only happen if the + // default value fails validation, which is arguably a developer error. + // We could try to "fix" this in the future if desired. + errors.push(err); + } + } + Err(err) => return Err(err), } } } @@ -269,7 +282,7 @@ impl Validator for DataclassArgsValidator { if let Some(kwargs) = $args.kwargs { if kwargs.len() != used_keys.len() { for (raw_key, value) in kwargs.iter() { - match raw_key.strict_str() { + match raw_key.validate_str(true, false).map(ValidationMatch::into_inner) { Ok(either_str) => { if !used_keys.contains(either_str.as_cow()?.as_ref()) { // Unknown / extra field @@ -426,28 +439,12 @@ impl Validator for DataclassArgsValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.fields - .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) - } - fn get_name(&self) -> &str { &self.validator_name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|field| field.validator.complete(definitions)) - } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DataclassValidator { strict: bool, validator: Box, @@ -548,6 +545,7 @@ impl Validator for DataclassValidator { )) } else { let val_output = self.validator.validate(py, input, state)?; + state.floor_exactness(Exactness::Strict); let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; Ok(dc) @@ -578,7 +576,7 @@ impl Validator for DataclassValidator { if self.slots { let value = dc_dict - .get_item(field_name) + .get_item(field_name)? .ok_or_else(|| PyKeyError::new_err(field_name.to_string()))?; force_setattr(py, obj, field_name, value)?; } else { @@ -588,25 +586,9 @@ impl Validator for DataclassValidator { Ok(obj.to_object(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) - } } impl DataclassValidator { diff --git a/src/validators/date.rs b/src/validators/date.rs index a771a5045..7c79101f4 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -11,6 +11,7 @@ use crate::input::{EitherDate, Input}; use crate::tools::SchemaDict; use crate::validators::datetime::{NowConstraint, NowOp}; +use super::Exactness; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] @@ -46,9 +47,12 @@ impl Validator for DateValidator { ) -> ValResult<'data, PyObject> { let strict = state.strict_or(self.strict); let date = match input.validate_date(strict) { - Ok(date) => date, + Ok(val_match) => val_match.unpack(state), // if the error was a parsing error, in lax mode we allow datetimes at midnight - Err(line_errors @ ValError::LineErrors(..)) if !strict => date_from_datetime(input)?.ok_or(line_errors)?, + Err(line_errors @ ValError::LineErrors(..)) if !strict => { + state.floor_exactness(Exactness::Lax); + date_from_datetime(input)?.ok_or(line_errors)? + } Err(otherwise) => return Err(otherwise), }; if let Some(constraints) = &self.constraints { @@ -96,21 +100,9 @@ impl Validator for DateValidator { Ok(date.try_into_py(py)?) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } /// In lax mode, if the input is not a date, we try parsing the input as a datetime, then check it is an @@ -119,7 +111,7 @@ impl Validator for DateValidator { /// Ok(None) means that this is not relevant to dates (the input was not a datetime nor a string) fn date_from_datetime<'data>(input: &'data impl Input<'data>) -> Result>, ValError<'data>> { let either_dt = match input.validate_datetime(false, speedate::MicrosecondsPrecisionOverflowBehavior::Truncate) { - Ok(dt) => dt, + Ok(val_match) => val_match.into_inner(), // if the error was a parsing error, update the error type from DatetimeParsing to DateFromDatetimeParsing // and return it Err(ValError::LineErrors(mut line_errors)) => { diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 7596b7aca..edbd399e7 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -65,7 +65,9 @@ impl Validator for DateTimeValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let strict = state.strict_or(self.strict); - let datetime = input.validate_datetime(strict, self.microseconds_precision)?; + let datetime = input + .validate_datetime(strict, self.microseconds_precision)? + .unpack(state); if let Some(constraints) = &self.constraints { // if we get an error from as_speedate, it's probably because the input datetime was invalid // specifically had an invalid tzinfo, hence here we return a validation error @@ -125,21 +127,9 @@ impl Validator for DateTimeValidator { Ok(datetime.try_into_py(py)?) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } #[derive(Debug, Clone)] @@ -260,7 +250,7 @@ impl TZConstraint { pub(super) fn from_py(schema: &PyDict) -> PyResult> { let py = schema.py(); - let tz_constraint = match schema.get_item(intern!(py, "tz_constraint")) { + let tz_constraint = match schema.get_item(intern!(py, "tz_constraint"))? { Some(c) => c, None => return Ok(None), }; diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index 2564e096a..b9435f046 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -1,7 +1,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::intern; use pyo3::sync::GILOnceCell; use pyo3::types::{IntoPyDict, PyDict, PyTuple, PyType}; -use pyo3::{intern, AsPyPointer}; use pyo3::{prelude::*, PyTypeInfo}; use crate::build_tools::{is_strict, schema_or_config_same}; @@ -83,6 +83,41 @@ impl_py_gc_traverse!(DecimalValidator { gt }); +fn extract_decimal_digits_info<'data>( + decimal: &PyAny, + normalized: bool, + py: Python<'data>, +) -> ValResult<'data, (u64, u64)> { + let mut normalized_decimal: Option<&PyAny> = None; + if normalized { + normalized_decimal = Some(decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal)); + } + let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = normalized_decimal + .unwrap_or(decimal) + .call_method0(intern!(py, "as_tuple"))? + .extract()?; + + // finite values have numeric exponent, we checked is_finite above + let exponent: i64 = exponent.extract()?; + let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?; + let decimals; + if exponent >= 0 { + // A positive exponent adds that many trailing zeros. + digits += exponent as u64; + decimals = 0; + } else { + // If the absolute value of the negative exponent is larger than the + // number of digits, then it's the same as the number of digits, + // because it'll consume all the digits in digit_tuple and then + // add abs(exponent) - len(digit_tuple) leading zeros after the + // decimal point. + decimals = exponent.unsigned_abs(); + digits = digits.max(decimals); + } + + Ok((decimals, digits)) +} + impl Validator for DecimalValidator { fn validate<'data>( &self, @@ -98,65 +133,53 @@ impl Validator for DecimalValidator { } if self.check_digits { - let normalized_value = decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal); - let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = - normalized_value.call_method0(intern!(py, "as_tuple"))?.extract()?; + if let Ok((normalized_decimals, normalized_digits)) = extract_decimal_digits_info(decimal, true, py) { + if let Ok((decimals, digits)) = extract_decimal_digits_info(decimal, false, py) { + if let Some(max_digits) = self.max_digits { + if (digits > max_digits) & (normalized_digits > max_digits) { + return Err(ValError::new( + ErrorType::DecimalMaxDigits { + max_digits, + context: None, + }, + input, + )); + } + } - // finite values have numeric exponent, we checked is_finite above - let exponent: i64 = exponent.extract()?; - let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?; - let decimals; - if exponent >= 0 { - // A positive exponent adds that many trailing zeros. - digits += exponent as u64; - decimals = 0; - } else { - // If the absolute value of the negative exponent is larger than the - // number of digits, then it's the same as the number of digits, - // because it'll consume all the digits in digit_tuple and then - // add abs(exponent) - len(digit_tuple) leading zeros after the - // decimal point. - decimals = exponent.unsigned_abs(); - digits = digits.max(decimals); - } + if let Some(decimal_places) = self.decimal_places { + if (decimals > decimal_places) & (normalized_decimals > decimal_places) { + return Err(ValError::new( + ErrorType::DecimalMaxPlaces { + decimal_places, + context: None, + }, + input, + )); + } - if let Some(max_digits) = self.max_digits { - if digits > max_digits { - return Err(ValError::new( - ErrorType::DecimalMaxDigits { - max_digits, - context: None, - }, - input, - )); - } - } + if let Some(max_digits) = self.max_digits { + let whole_digits = digits.saturating_sub(decimals); + let max_whole_digits = max_digits.saturating_sub(decimal_places); - if let Some(decimal_places) = self.decimal_places { - if decimals > decimal_places { - return Err(ValError::new( - ErrorType::DecimalMaxPlaces { - decimal_places, - context: None, - }, - input, - )); - } + let normalized_whole_digits = normalized_digits.saturating_sub(normalized_decimals); + let normalized_max_whole_digits = max_digits.saturating_sub(decimal_places); - if let Some(max_digits) = self.max_digits { - let whole_digits = digits.saturating_sub(decimals); - let max_whole_digits = max_digits.saturating_sub(decimal_places); - if whole_digits > max_whole_digits { - return Err(ValError::new( - ErrorType::DecimalWholeDigits { - whole_digits: max_whole_digits, - context: None, - }, - input, - )); + if (whole_digits > max_whole_digits) + & (normalized_whole_digits > normalized_max_whole_digits) + { + return Err(ValError::new( + ErrorType::DecimalWholeDigits { + whole_digits: max_whole_digits, + context: None, + }, + input, + )); + } + } } } - } + }; } } @@ -182,8 +205,19 @@ impl Validator for DecimalValidator { } } + // Decimal raises DecimalOperation when comparing NaN, so if it's necessary to compare + // the value to a number, we need to check for NaN first. We cache the result on the first + // time we check it. + let mut is_nan: Option = None; + let mut is_nan = || -> PyResult { + match is_nan { + Some(is_nan) => Ok(is_nan), + None => Ok(*is_nan.insert(decimal.call_method0(intern!(py, "is_nan"))?.extract()?)), + } + }; + if let Some(le) = &self.le { - if !decimal.le(le)? { + if is_nan()? || !decimal.le(le)? { return Err(ValError::new( ErrorType::LessThanEqual { le: Number::String(le.to_string()), @@ -194,7 +228,7 @@ impl Validator for DecimalValidator { } } if let Some(lt) = &self.lt { - if !decimal.lt(lt)? { + if is_nan()? || !decimal.lt(lt)? { return Err(ValError::new( ErrorType::LessThan { lt: Number::String(lt.to_string()), @@ -205,7 +239,7 @@ impl Validator for DecimalValidator { } } if let Some(ge) = &self.ge { - if !decimal.ge(ge)? { + if is_nan()? || !decimal.ge(ge)? { return Err(ValError::new( ErrorType::GreaterThanEqual { ge: Number::String(ge.to_string()), @@ -216,7 +250,7 @@ impl Validator for DecimalValidator { } } if let Some(gt) = &self.gt { - if !decimal.gt(gt)? { + if is_nan()? || !decimal.gt(gt)? { return Err(ValError::new( ErrorType::GreaterThan { gt: Number::String(gt.to_string()), @@ -230,21 +264,9 @@ impl Validator for DecimalValidator { Ok(decimal.into()) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { - true - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } pub(crate) fn create_decimal<'a>( diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 3a35fce4c..979278bb9 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -2,6 +2,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; +use crate::definitions::DefinitionRef; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; @@ -39,17 +40,12 @@ impl BuildValidator for DefinitionsValidatorBuilder { #[derive(Debug, Clone)] pub struct DefinitionRefValidator { - validator_id: usize, - inner_name: String, - // we have to record the answers to `Question`s as we can't access the validator when `ask()` is called + definition: DefinitionRef, } impl DefinitionRefValidator { - pub fn new(validator_id: usize) -> Self { - Self { - validator_id, - inner_name: "...".to_string(), - } + pub fn new(definition: DefinitionRef) -> Self { + Self { definition } } } @@ -61,15 +57,10 @@ impl BuildValidator for DefinitionRefValidator { _config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let schema_ref: String = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - - let validator_id = definitions.get_reference_id(&schema_ref); + let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - Ok(Self { - validator_id, - inner_name: "...".to_string(), - } - .into()) + let definition = definitions.get_definition(schema_ref); + Ok(Self::new(definition).into()) } } @@ -82,21 +73,22 @@ impl Validator for DefinitionRefValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + let validator = self.definition.get().unwrap(); if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.validator_id) { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) } else { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } - let output = validate(self.validator_id, py, input, state); - state.recursion_guard.remove(id, self.validator_id); + let output = validator.validate(py, input, state); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } } else { - validate(self.validator_id, py, input, state) + validator.validate(py, input, state) } } @@ -108,69 +100,26 @@ impl Validator for DefinitionRefValidator { field_value: &'data PyAny, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + let validator = self.definition.get().unwrap(); if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.validator_id) { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) } else { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } - let output = validate_assignment(self.validator_id, py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.validator_id); + let output = validator.validate_assignment(py, obj, field_name, field_value, state); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } } else { - validate_assignment(self.validator_id, py, obj, field_name, field_value, state) - } - } - - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if let Some(definitions) = definitions { - // have to unwrap here, because we can't return an error from this function, should be okay - let validator = definitions.get_definition(self.validator_id).unwrap(); - validator.different_strict_behavior(None, ultra_strict) - } else { - false + validator.validate_assignment(py, obj, field_name, field_value, state) } } fn get_name(&self) -> &str { - &self.inner_name - } - - /// don't need to call complete on the inner validator here, complete_validators takes care of that. - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - let validator = definitions.get_definition(self.validator_id)?; - self.inner_name = validator.get_name().to_string(); - Ok(()) + self.definition.get_or_init_name(|v| v.get_name().into()) } } - -fn validate<'data>( - validator_id: usize, - py: Python<'data>, - input: &'data impl Input<'data>, - state: &mut ValidationState, -) -> ValResult<'data, PyObject> { - let validator = state.definitions.get(validator_id).unwrap(); - validator.validate(py, input, state) -} - -#[allow(clippy::too_many_arguments)] -fn validate_assignment<'data>( - validator_id: usize, - py: Python<'data>, - obj: &'data PyAny, - field_name: &'data str, - field_value: &'data PyAny, - state: &mut ValidationState, -) -> ValResult<'data, PyObject> { - let validator = state.definitions.get(validator_id).unwrap(); - validator.validate_assignment(py, obj, field_name, field_value, state) -} diff --git a/src/validators/dict.rs b/src/validators/dict.rs index dc8f03937..3ac284b2a 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -3,7 +3,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::is_strict; -use crate::errors::{ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ValError, ValLineError, ValResult}; use crate::input::BorrowInput; use crate::input::{ DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, @@ -16,7 +16,7 @@ use super::any::AnyValidator; use super::list::length_check; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DictValidator { strict: bool, key_validator: Box, @@ -35,11 +35,11 @@ impl BuildValidator for DictValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let key_validator = match schema.get_item(intern!(py, "keys_schema")) { + let key_validator = match schema.get_item(intern!(py, "keys_schema"))? { Some(schema) => Box::new(build_validator(schema, config, definitions)?), None => Box::new(AnyValidator::build(schema, config, definitions)?), }; - let value_validator = match schema.get_item(intern!(py, "values_schema")) { + let value_validator = match schema.get_item(intern!(py, "values_schema"))? { Some(d) => Box::new(build_validator(d, config, definitions)?), None => Box::new(AnyValidator::build(schema, config, definitions)?), }; @@ -80,6 +80,7 @@ impl Validator for DictValidator { self.validate_generic_mapping(py, input, DictGenericIterator::new(py_dict)?, state) } GenericMapping::PyMapping(mapping) => { + state.floor_exactness(super::Exactness::Lax); self.validate_generic_mapping(py, input, MappingGenericIterator::new(mapping)?, state) } GenericMapping::StringMapping(dict) => { @@ -92,27 +93,9 @@ impl Validator for DictValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - self.key_validator.different_strict_behavior(definitions, true) - || self.value_validator.different_strict_behavior(definitions, true) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.key_validator.complete(definitions)?; - self.value_validator.complete(definitions) - } } impl DictValidator { diff --git a/src/validators/float.rs b/src/validators/float.rs index f0eb41750..b72ffafc0 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -19,11 +21,11 @@ impl BuildValidator for FloatBuilder { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let use_constrained = schema.get_item(intern!(py, "multiple_of")).is_some() - || schema.get_item(intern!(py, "le")).is_some() - || schema.get_item(intern!(py, "lt")).is_some() - || schema.get_item(intern!(py, "ge")).is_some() - || schema.get_item(intern!(py, "gt")).is_some(); + let use_constrained = schema.get_item(intern!(py, "multiple_of"))?.is_some() + || schema.get_item(intern!(py, "le"))?.is_some() + || schema.get_item(intern!(py, "lt"))?.is_some() + || schema.get_item(intern!(py, "ge"))?.is_some() + || schema.get_item(intern!(py, "gt"))?.is_some(); if use_constrained { ConstrainedFloatValidator::build(schema, config, definitions) } else { @@ -68,29 +70,16 @@ impl Validator for FloatValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let strict = state.strict_or(self.strict); - let either_float = input.validate_float(strict, state.extra().ultra_strict)?; + let either_float = input.validate_float(state.strict_or(self.strict))?.unpack(state); if !self.allow_inf_nan && !either_float.as_f64().is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); } Ok(either_float.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { - true - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } #[derive(Debug, Clone)] @@ -113,8 +102,7 @@ impl Validator for ConstrainedFloatValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let strict = state.strict_or(self.strict); - let either_float = input.validate_float(strict, state.extra().ultra_strict)?; + let either_float = input.validate_float(state.strict_or(self.strict))?.unpack(state); let float: f64 = either_float.as_f64(); if !self.allow_inf_nan && !float.is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); @@ -133,7 +121,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(le) = self.le { - if float > le { + if !matches!(float.partial_cmp(&le), Some(Ordering::Less | Ordering::Equal)) { return Err(ValError::new( ErrorType::LessThanEqual { le: le.into(), @@ -144,7 +132,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(lt) = self.lt { - if float >= lt { + if !matches!(float.partial_cmp(<), Some(Ordering::Less)) { return Err(ValError::new( ErrorType::LessThan { lt: lt.into(), @@ -155,7 +143,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(ge) = self.ge { - if float < ge { + if !matches!(float.partial_cmp(&ge), Some(Ordering::Greater | Ordering::Equal)) { return Err(ValError::new( ErrorType::GreaterThanEqual { ge: ge.into(), @@ -166,7 +154,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(gt) = self.gt { - if float <= gt { + if !matches!(float.partial_cmp(>), Some(Ordering::Greater)) { return Err(ValError::new( ErrorType::GreaterThan { gt: gt.into(), @@ -179,21 +167,9 @@ impl Validator for ConstrainedFloatValidator { Ok(either_float.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { - true - } - fn get_name(&self) -> &str { "constrained-float" } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } impl BuildValidator for ConstrainedFloatValidator { diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index ad7708324..190a8672d 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -2,15 +2,16 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyFrozenSet}; use crate::errors::ValResult; -use crate::input::Input; +use crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; +use crate::validators::Exactness; use super::list::min_length_check; use super::set::set_build; use super::validation_state::ValidationState; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FrozenSetValidator { strict: bool, item_validator: Box, @@ -34,6 +35,12 @@ impl Validator for FrozenSetValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let collection = input.validate_frozenset(state.strict_or(self.strict))?; + let exactness = match &collection { + GenericIterable::FrozenSet(_) => Exactness::Exact, + GenericIterable::Set(_) | GenericIterable::JsonArray(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); let f_set = PyFrozenSet::empty(py)?; collection.validate_to_set( py, @@ -48,23 +55,7 @@ impl Validator for FrozenSetValidator { Ok(f_set.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - self.item_validator.different_strict_behavior(definitions, true) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.item_validator.complete(definitions) - } } diff --git a/src/validators/function.rs b/src/validators/function.rs index be0d6374f..4c5ad9c29 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -1,14 +1,16 @@ -use pyo3::exceptions::{PyAssertionError, PyTypeError, PyValueError}; +use std::sync::Arc; + +use pyo3::exceptions::{PyAssertionError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::errors::{ - ErrorType, LocItem, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError, + AsLocItem, ErrorType, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError, }; use crate::input::Input; use crate::py_gc::PyGcTraverse; -use crate::tools::{function_name, py_err, safe_repr, SchemaDict}; +use crate::tools::{function_name, safe_repr, SchemaDict}; use crate::PydanticUseDefault; use super::generator::InternalValidator; @@ -111,31 +113,14 @@ macro_rules! impl_validator { self._validate(validate, py, obj, state) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - self.validator - .different_strict_behavior(definitions, ultra_strict) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) - } } }; } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionBeforeValidator { validator: Box, func: PyObject, @@ -168,7 +153,7 @@ impl FunctionBeforeValidator { impl_validator!(FunctionBeforeValidator); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionAfterValidator { validator: Box, func: PyObject, @@ -255,27 +240,14 @@ impl Validator for FunctionPlainValidator { r.map_err(|e| convert_err(py, e, input)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - // best guess, should we change this? - !ultra_strict - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionWrapValidator { - validator: Box, + validator: Arc, func: PyObject, config: PyObject, name: String, @@ -299,7 +271,7 @@ impl BuildValidator for FunctionWrapValidator { let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false); let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false); Ok(Self { - validator: Box::new(validator), + validator: Arc::new(validator), func: function_info.function.clone(), config: match config { Some(c) => c.into(), @@ -350,18 +322,16 @@ impl Validator for FunctionWrapValidator { validator: InternalValidator::new( py, "ValidatorCallable", - &self.validator, + self.validator.clone(), state, self.hide_input_in_errors, self.validation_error_cause, ), }; - self._validate( - Py::new(py, handler)?.into_ref(py), - py, - input.to_object(py).into_ref(py), - state, - ) + let handler = Py::new(py, handler)?.into_ref(py); + let result = self._validate(handler, py, input.to_object(py).into_ref(py), state); + state.exactness = handler.borrow_mut().validator.exactness; + result } fn validate_assignment<'data>( @@ -376,7 +346,7 @@ impl Validator for FunctionWrapValidator { validator: InternalValidator::new( py, "AssignmentValidatorCallable", - &self.validator, + self.validator.clone(), state, self.hide_input_in_errors, self.validation_error_cause, @@ -387,29 +357,13 @@ impl Validator for FunctionWrapValidator { self._validate(Py::new(py, handler)?.into_ref(py), py, obj, state) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) - } } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct ValidatorCallable { validator: InternalValidator, } @@ -417,13 +371,7 @@ struct ValidatorCallable { #[pymethods] impl ValidatorCallable { fn __call__(&mut self, py: Python, input_value: &PyAny, outer_location: Option<&PyAny>) -> PyResult { - let outer_location = match outer_location { - Some(ol) => match LocItem::try_from(ol) { - Ok(ol) => Some(ol), - Err(_) => return py_err!(PyTypeError; "outer_location must be a str or int"), - }, - None => None, - }; + let outer_location = outer_location.map(AsLocItem::as_loc_item); self.validator.validate(py, input_value, outer_location) } @@ -441,7 +389,7 @@ impl ValidatorCallable { } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct AssignmentValidatorCallable { updated_field_name: String, updated_field_value: Py, @@ -451,13 +399,7 @@ struct AssignmentValidatorCallable { #[pymethods] impl AssignmentValidatorCallable { fn __call__(&mut self, py: Python, input_value: &PyAny, outer_location: Option<&PyAny>) -> PyResult { - let outer_location = match outer_location { - Some(ol) => match LocItem::try_from(ol) { - Ok(ol) => Some(ol), - Err(_) => return py_err!(PyTypeError; "outer_location must be a str or int"), - }, - None => None, - }; + let outer_location = outer_location.map(AsLocItem::as_loc_item); self.validator.validate_assignment( py, input_value, diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 0cff7e28e..94497d228 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::sync::Arc; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -10,11 +11,13 @@ use crate::tools::SchemaDict; use crate::ValidationError; use super::list::get_items_schema; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, InputType, ValidationState, Validator}; +use super::{ + BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, Extra, InputType, ValidationState, Validator, +}; #[derive(Debug, Clone)] pub struct GeneratorValidator { - item_validator: Option>, + item_validator: Option>, min_length: Option, max_length: Option, name: String, @@ -30,7 +33,7 @@ impl BuildValidator for GeneratorValidator { config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let item_validator = get_items_schema(schema, config, definitions)?; + let item_validator = get_items_schema(schema, config, definitions)?.map(Arc::new); let name = match item_validator { Some(ref v) => format!("{}[{}]", Self::EXPECTED_TYPE, v.get_name()), None => format!("{}[any]", Self::EXPECTED_TYPE), @@ -67,7 +70,7 @@ impl Validator for GeneratorValidator { InternalValidator::new( py, "ValidatorIterator", - v, + v.clone(), state, self.hide_input_in_errors, self.validation_error_cause, @@ -85,32 +88,13 @@ impl Validator for GeneratorValidator { Ok(v_iterator.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if let Some(ref v) = self.item_validator { - v.different_strict_behavior(definitions, ultra_strict) - } else { - false - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), - None => Ok(()), - } - } } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct ValidatorIterator { iterator: GenericIterator, validator: Option, @@ -217,13 +201,11 @@ impl ValidatorIterator { } } -/// Cloneable validator wrapper for use in generators in functions, this can be passed back to python +/// Owned validator wrapper for use in generators in functions, this can be passed back to python /// mid-validation -#[derive(Clone)] pub struct InternalValidator { name: String, - validator: CombinedValidator, - definitions: Vec, + validator: Arc, // TODO, do we need data? data: Option>, strict: Option, @@ -231,6 +213,7 @@ pub struct InternalValidator { context: Option, self_instance: Option, recursion_guard: RecursionGuard, + pub(crate) exactness: Option, validation_mode: InputType, hide_input_in_errors: bool, validation_error_cause: bool, @@ -246,7 +229,7 @@ impl InternalValidator { pub fn new( py: Python, name: &str, - validator: &CombinedValidator, + validator: Arc, state: &ValidationState, hide_input_in_errors: bool, validation_error_cause: bool, @@ -254,14 +237,14 @@ impl InternalValidator { let extra = state.extra(); Self { name: name.to_string(), - validator: validator.clone(), - definitions: state.definitions.to_vec(), + validator, data: extra.data.map(|d| d.into_py(py)), strict: extra.strict, from_attributes: extra.from_attributes, context: extra.context.map(|d| d.into_py(py)), self_instance: extra.self_instance.map(|d| d.into_py(py)), recursion_guard: state.recursion_guard.clone(), + exactness: state.exactness, validation_mode: extra.input_type, hide_input_in_errors, validation_error_cause, @@ -280,13 +263,14 @@ impl InternalValidator { input_type: self.validation_mode, data: self.data.as_ref().map(|data| data.as_ref(py)), strict: self.strict, - ultra_strict: false, from_attributes: self.from_attributes, context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; - let mut state = ValidationState::new(extra, &self.definitions, &mut self.recursion_guard); - self.validator + let mut state = ValidationState::new(extra, &mut self.recursion_guard); + state.exactness = self.exactness; + let result = self + .validator .validate_assignment(py, model, field_name, field_value, &mut state) .map_err(|e| { ValidationError::from_val_error( @@ -298,7 +282,9 @@ impl InternalValidator { self.hide_input_in_errors, self.validation_error_cause, ) - }) + }); + self.exactness = state.exactness; + result } pub fn validate<'data>( @@ -311,13 +297,13 @@ impl InternalValidator { input_type: self.validation_mode, data: self.data.as_ref().map(|data| data.as_ref(py)), strict: self.strict, - ultra_strict: false, from_attributes: self.from_attributes, context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; - let mut state = ValidationState::new(extra, &self.definitions, &mut self.recursion_guard); - self.validator.validate(py, input, &mut state).map_err(|e| { + let mut state = ValidationState::new(extra, &mut self.recursion_guard); + state.exactness = self.exactness; + let result = self.validator.validate(py, input, &mut state).map_err(|e| { ValidationError::from_val_error( py, self.name.to_object(py), @@ -327,13 +313,14 @@ impl InternalValidator { self.hide_input_in_errors, self.validation_error_cause, ) - }) + }); + self.exactness = state.exactness; + result } } impl_py_gc_traverse!(InternalValidator { validator, - definitions, data, context, self_instance diff --git a/src/validators/int.rs b/src/validators/int.rs index 3fba2199d..dabfb5115 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -8,8 +8,7 @@ use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::{Input, Int}; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct IntValidator { @@ -25,11 +24,11 @@ impl BuildValidator for IntValidator { _definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let use_constrained = schema.get_item(intern!(py, "multiple_of")).is_some() - || schema.get_item(intern!(py, "le")).is_some() - || schema.get_item(intern!(py, "lt")).is_some() - || schema.get_item(intern!(py, "ge")).is_some() - || schema.get_item(intern!(py, "gt")).is_some(); + let use_constrained = schema.get_item(intern!(py, "multiple_of"))?.is_some() + || schema.get_item(intern!(py, "le"))?.is_some() + || schema.get_item(intern!(py, "lt"))?.is_some() + || schema.get_item(intern!(py, "ge"))?.is_some() + || schema.get_item(intern!(py, "gt"))?.is_some(); if use_constrained { ConstrainedIntValidator::build(schema, config) } else { @@ -50,25 +49,14 @@ impl Validator for IntValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_int = input.validate_int(state.strict_or(self.strict))?; - Ok(either_int.into_py(py)) - } - - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict + input + .validate_int(state.strict_or(self.strict)) + .map(|val_match| val_match.unpack(state).into_py(py)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } #[derive(Debug, Clone)] @@ -90,7 +78,7 @@ impl Validator for ConstrainedIntValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_int = input.validate_int(state.strict_or(self.strict))?; + let either_int = input.validate_int(state.strict_or(self.strict))?.unpack(state); let int_value = either_int.as_int()?; if let Some(ref multiple_of) = self.multiple_of { @@ -151,21 +139,9 @@ impl Validator for ConstrainedIntValidator { Ok(either_int.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { "constrained-int" } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } impl ConstrainedIntValidator { diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index 78705482c..189589d6a 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -8,8 +8,7 @@ use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct IsInstanceValidator { @@ -83,19 +82,7 @@ impl Validator for IsInstanceValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { - false - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/is_subclass.rs b/src/validators/is_subclass.rs index d0f5a6cfe..7a89ef36c 100644 --- a/src/validators/is_subclass.rs +++ b/src/validators/is_subclass.rs @@ -6,8 +6,7 @@ use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct IsSubclassValidator { @@ -62,19 +61,7 @@ impl Validator for IsSubclassValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { - false - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/json.rs b/src/validators/json.rs index 5eda007be..9dfb5fae2 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -6,10 +6,9 @@ use crate::errors::ValResult; use crate::input::Input; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct JsonValidator { validator: Option>, name: String, @@ -61,26 +60,7 @@ impl Validator for JsonValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if let Some(ref v) = self.validator { - v.different_strict_behavior(definitions, ultra_strict) - } else { - false - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.validator { - Some(ref mut v) => v.complete(definitions), - None => Ok(()), - } - } } diff --git a/src/validators/json_or_python.rs b/src/validators/json_or_python.rs index 828532fe5..302cbdaf6 100644 --- a/src/validators/json_or_python.rs +++ b/src/validators/json_or_python.rs @@ -7,11 +7,9 @@ use crate::errors::ValResult; use crate::input::Input; use crate::tools::SchemaDict; -use super::InputType; -use super::ValidationState; -use super::{build_validator, BuildValidator, CombinedValidator, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, InputType, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct JsonOrPython { json: Box, python: Box, @@ -63,21 +61,7 @@ impl Validator for JsonOrPython { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.json.different_strict_behavior(definitions, ultra_strict) - || self.python.different_strict_behavior(definitions, ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.json.complete(definitions)?; - self.python.complete(definitions) - } } diff --git a/src/validators/lax_or_strict.rs b/src/validators/lax_or_strict.rs index 9681cf689..78021cd4c 100644 --- a/src/validators/lax_or_strict.rs +++ b/src/validators/lax_or_strict.rs @@ -7,10 +7,11 @@ use crate::errors::ValResult; use crate::input::Input; use crate::tools::SchemaDict; +use super::Exactness; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct LaxOrStrictValidator { strict: bool, lax_validator: Box, @@ -64,28 +65,20 @@ impl Validator for LaxOrStrictValidator { if state.strict_or(self.strict) { self.strict_validator.validate(py, input, state) } else { + // horrible edge case: if doing smart union validation, we need to try the strict validator + // anyway and prefer that if it succeeds + if state.exactness.is_some() { + if let Ok(strict_result) = self.strict_validator.validate(py, input, state) { + return Ok(strict_result); + } + // this is now known to be not strict + state.floor_exactness(Exactness::Lax); + } self.lax_validator.validate(py, input, state) } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - self.strict_validator.different_strict_behavior(definitions, true) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.lax_validator.complete(definitions)?; - self.strict_validator.complete(definitions) - } } diff --git a/src/validators/list.rs b/src/validators/list.rs index ffd7a118e..b2e0ff116 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -1,32 +1,35 @@ +use std::sync::OnceLock; + use pyo3::prelude::*; use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; +use crate::validators::Exactness; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ListValidator { strict: bool, item_validator: Option>, min_length: Option, max_length: Option, - name: String, + name: OnceLock, } pub fn get_items_schema( schema: &PyDict, config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, -) -> PyResult>> { - match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { +) -> PyResult> { + match schema.get_item(pyo3::intern!(schema.py(), "items_schema"))? { Some(d) => { let validator = build_validator(d, config, definitions)?; match validator { CombinedValidator::Any(_) => Ok(None), - _ => Ok(Some(Box::new(validator))), + _ => Ok(Some(validator)), } } None => Ok(None), @@ -98,15 +101,13 @@ impl BuildValidator for ListValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; - let inner_name = item_validator.as_ref().map_or("any", |v| v.get_name()); - let name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); + let item_validator = get_items_schema(schema, config, definitions)?.map(Box::new); Ok(Self { strict: crate::build_tools::is_strict(schema, config)?, item_validator, min_length: schema.get_as(pyo3::intern!(py, "min_length"))?, max_length: schema.get_as(pyo3::intern!(py, "max_length"))?, - name, + name: OnceLock::new(), } .into()) } @@ -122,6 +123,12 @@ impl Validator for ListValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let seq = input.validate_list(state.strict_or(self.strict))?; + let exactness = match &seq { + GenericIterable::List(_) | GenericIterable::JsonArray(_) => Exactness::Exact, + GenericIterable::Tuple(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); let output = match self.item_validator { Some(ref v) => seq.validate_to_vec(py, input, self.max_length, "List", v, state)?, @@ -138,31 +145,22 @@ impl Validator for ListValidator { Ok(output.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), - None => false, - } - } else { - true - } - } - fn get_name(&self) -> &str { - &self.name - } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - if let Some(ref mut v) = self.item_validator { - v.complete(definitions)?; - let inner_name = v.get_name(); - self.name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); + // The logic here is a little janky, it's done to try to cache the formatted name + // while also trying to render definitions correctly when possible. + // + // Probably an opportunity for a future refactor + match self.name.get() { + Some(s) => s.as_str(), + None => { + let name = self.item_validator.as_ref().map_or("any", |v| v.get_name()); + if name == "..." { + // when inner name is not initialized yet, don't cache it here + "list[...]" + } else { + self.name.get_or_init(|| format!("list[{name}]")).as_str() + } + } } - Ok(()) } } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index de394affb..c9a846695 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -22,7 +22,7 @@ struct BoolLiteral { } #[derive(Debug, Clone)] -pub struct LiteralLookup { +pub struct LiteralLookup { // Specialized lookups for ints, bools and strings because they // (1) are easy to convert between Rust and Python // (2) hashing them in Rust is very fast @@ -35,7 +35,7 @@ pub struct LiteralLookup { pub values: Vec, } -impl LiteralLookup { +impl LiteralLookup { pub fn new<'py>(py: Python<'py>, expected: impl Iterator) -> PyResult { let mut expected_int = AHashMap::new(); let mut expected_str: AHashMap = AHashMap::new(); @@ -48,8 +48,8 @@ impl LiteralLookup { for (k, v) in expected { let id = values.len(); values.push(v); - if let Ok(bool) = k.strict_bool() { - if bool { + if let Ok(bool) = k.validate_bool(true) { + if bool.into_inner() { expected_bool.true_id = Some(id); } else { expected_bool.false_id = Some(id); @@ -97,8 +97,8 @@ impl LiteralLookup { input: &'data I, ) -> ValResult<'data, Option<(&'data I, &T)>> { if let Some(expected_bool) = &self.expected_bool { - if let Ok(bool_value) = input.strict_bool() { - if bool_value { + if let Ok(bool_value) = input.validate_bool(true) { + if bool_value.into_inner() { if let Some(true_value) = &expected_bool.true_id { return Ok(Some((input, &self.values[*true_value]))); } @@ -126,7 +126,7 @@ impl LiteralLookup { } // must be an enum or bytes if let Some(expected_py) = &self.expected_py { - if let Some(v) = expected_py.as_ref(py).get_item(input) { + if let Some(v) = expected_py.as_ref(py).get_item(input)? { let id: usize = v.extract().unwrap(); return Ok(Some((input, &self.values[id]))); } @@ -135,7 +135,7 @@ impl LiteralLookup { } } -impl PyGcTraverse for LiteralLookup { +impl PyGcTraverse for LiteralLookup { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { self.expected_py.py_gc_traverse(visit)?; self.values.py_gc_traverse(visit)?; @@ -198,21 +198,9 @@ impl Validator for LiteralValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } pub fn expected_repr_name(mut repr_args: Vec, base_name: &'static str) -> (String, String) { diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 4ee677663..f541ea45d 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -9,7 +9,7 @@ use pyo3::types::{PyAny, PyDict, PyTuple, PyType}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::build_tools::{py_schema_err, py_schema_error_type, SchemaError}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::errors::{LocItem, ValError, ValResult, ValidationError}; use crate::input::{Input, InputType, StringMapping}; use crate::py_gc::PyGcTraverse; @@ -58,10 +58,9 @@ mod uuid; mod validation_state; mod with_default; +pub use self::validation_state::{Exactness, ValidationState}; pub use with_default::DefaultType; -pub use self::validation_state::ValidationState; - #[pyclass(module = "pydantic_core._pydantic_core", name = "Some")] pub struct PySome { #[pyo3(get)] @@ -97,12 +96,15 @@ impl PySome { } } -#[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[pyclass(module = "pydantic_core._pydantic_core", frozen)] +#[derive(Debug)] pub struct SchemaValidator { validator: CombinedValidator, - definitions: Vec, - schema: PyObject, + definitions: Definitions, + // References to the Python schema and config objects are saved to enable + // reconstructing the object for cloudpickle support (see `__reduce__`). + py_schema: Py, + py_config: Option>, #[pyo3(get)] title: PyObject, hide_input_in_errors: bool, @@ -115,14 +117,15 @@ impl SchemaValidator { pub fn py_new(py: Python, schema: &PyAny, config: Option<&PyDict>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); - let mut validator = build_validator(schema, config, &mut definitions_builder)?; - validator.complete(&definitions_builder)?; - let mut definitions = definitions_builder.clone().finish()?; - for val in &mut definitions { - val.complete(&definitions_builder)?; - } + let validator = build_validator(schema, config, &mut definitions_builder)?; + let definitions = definitions_builder.finish()?; + let py_schema = schema.into_py(py); + let py_config = match config { + Some(c) if !c.is_empty() => Some(c.into_py(py)), + _ => None, + }; let config_title = match config { - Some(c) => c.get_item("title"), + Some(c) => c.get_item("title")?, None => None, }; let title = match config_title { @@ -134,17 +137,20 @@ impl SchemaValidator { Ok(Self { validator, definitions, - schema: schema.into_py(py), + py_schema, + py_config, title, hide_input_in_errors, validation_error_cause, }) } - pub fn __reduce__(&self, py: Python) -> PyResult { - let args = (self.schema.as_ref(py),); - let cls = Py::new(py, self.clone())?.getattr(py, "__class__")?; - Ok((cls, args).into_py(py)) + pub fn __reduce__(slf: &PyCell) -> PyResult<(PyObject, (PyObject, PyObject))> { + // Enables support for `pickle` serialization. + let py = slf.py(); + let cls = slf.get_type().into(); + let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py)); + Ok((cls, init_args)) } #[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None))] @@ -260,13 +266,12 @@ impl SchemaValidator { data: None, strict, from_attributes, - ultra_strict: false, context, self_instance: None, }; let guard = &mut RecursionGuard::default(); - let mut state = ValidationState::new(extra, &self.definitions, guard); + let mut state = ValidationState::new(extra, guard); self.validator .validate_assignment(py, obj, field_name, field_value, &mut state) .map_err(|e| self.prepare_validation_err(py, e, InputType::Python)) @@ -279,12 +284,11 @@ impl SchemaValidator { data: None, strict, from_attributes: None, - ultra_strict: false, context, self_instance: None, }; let recursion_guard = &mut RecursionGuard::default(); - let mut state = ValidationState::new(extra, &self.definitions, recursion_guard); + let mut state = ValidationState::new(extra, recursion_guard); let r = self.validator.default_value(py, None::, &mut state); match r { Ok(maybe_default) => match maybe_default { @@ -306,9 +310,9 @@ impl SchemaValidator { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { self.validator.py_gc_traverse(&visit)?; - visit.call(&self.schema)?; - for slot in &self.definitions { - slot.py_gc_traverse(&visit)?; + visit.call(&self.py_schema)?; + if let Some(ref py_config) = self.py_config { + visit.call(py_config)?; } Ok(()) } @@ -332,7 +336,6 @@ impl SchemaValidator { { let mut state = ValidationState::new( Extra::new(strict, from_attributes, context, self_instance, input_type), - &self.definitions, recursion_guard, ); self.validator.validate(py, input, &mut state) @@ -371,7 +374,6 @@ impl<'py> SelfValidator<'py> { let mut recursion_guard = RecursionGuard::default(); let mut state = ValidationState::new( Extra::new(strict, None, None, None, InputType::Python), - &self.validator.definitions, &mut recursion_guard, ); match self.validator.validator.validate(py, schema, &mut state) { @@ -388,19 +390,16 @@ impl<'py> SelfValidator<'py> { let mut definitions_builder = DefinitionsBuilder::new(); - let mut validator = match build_validator(self_schema, None, &mut definitions_builder) { + let validator = match build_validator(self_schema, None, &mut definitions_builder) { Ok(v) => v, Err(err) => return py_schema_err!("Error building self-schema:\n {}", err), }; - validator.complete(&definitions_builder)?; - let mut definitions = definitions_builder.clone().finish()?; - for val in &mut definitions { - val.complete(&definitions_builder)?; - } + let definitions = definitions_builder.finish()?; Ok(SchemaValidator { validator, definitions, - schema: py.None(), + py_schema: py.None(), + py_config: None, title: "Self Schema".into_py(py), hide_input_in_errors: false, validation_error_cause: false, @@ -559,8 +558,6 @@ pub struct Extra<'a> { pub data: Option<&'a PyDict>, /// whether we're in strict or lax mode pub strict: Option, - /// whether we're in ultra-strict mode, only used occasionally in unions - pub ultra_strict: bool, /// Validation time setting of `from_attributes` pub from_attributes: Option, /// context used in validator functions @@ -581,7 +578,6 @@ impl<'a> Extra<'a> { input_type, data: None, strict, - ultra_strict: false, from_attributes, context, self_instance, @@ -590,12 +586,11 @@ impl<'a> Extra<'a> { } impl<'a> Extra<'a> { - pub fn as_strict(&self, ultra_strict: bool) -> Self { + pub fn as_strict(&self) -> Self { Self { input_type: self.input_type, data: self.data, strict: Some(true), - ultra_strict, from_attributes: self.from_attributes, context: self.context, self_instance: self.self_instance, @@ -603,7 +598,7 @@ impl<'a> Extra<'a> { } } -#[derive(Debug, Clone)] +#[derive(Debug)] #[enum_dispatch(PyGcTraverse)] pub enum CombinedValidator { // typed dict e.g. heterogeneous dicts or simply a model @@ -699,7 +694,7 @@ pub enum CombinedValidator { /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, /// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait #[enum_dispatch(CombinedValidator)] -pub trait Validator: Send + Sync + Clone + Debug { +pub trait Validator: Send + Sync + Debug { /// Do the actual validation for this schema/type fn validate<'data>( &self, @@ -732,19 +727,7 @@ pub trait Validator: Send + Sync + Clone + Debug { Err(py_err.into()) } - /// whether the validator behaves differently in strict mode, and in ultra strict mode - /// implementations should return true if any of their sub-validators return true - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool; - /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator /// this is used in the error location in unions, and in the top level message in `ValidationError` fn get_name(&self) -> &str; - - /// this method must be implemented for any validator which holds references to other validators, - /// it is used by `DefinitionRefValidator` to set its name - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()>; } diff --git a/src/validators/model.rs b/src/validators/model.rs index 2ec7185a9..0299ce5d8 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -7,6 +7,7 @@ use pyo3::types::{PyDict, PySet, PyString, PyTuple, PyType}; use pyo3::{ffi, intern}; use super::function::convert_err; +use super::validation_state::Exactness; use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; @@ -50,7 +51,7 @@ impl Revalidate { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ModelValidator { revalidate: Revalidate, validator: Box, @@ -143,6 +144,8 @@ impl Validator for ModelValidator { Ok(input.to_object(py)) } } else { + // Having to construct a new model is not an exact match + state.floor_exactness(Exactness::Strict); self.validate_construct(py, input, None, state) } } @@ -206,25 +209,9 @@ impl Validator for ModelValidator { Ok(model.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) - } } impl ModelValidator { diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index f2654c33e..a284bd4e9 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -7,20 +7,21 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, - MappingGenericIterator, StringMappingGenericIterator, + MappingGenericIterator, StringMappingGenericIterator, ValidationMatch, }; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, +}; use std::ops::ControlFlow; -#[derive(Debug, Clone)] +#[derive(Debug)] struct Field { name: String, lookup_key: LookupKey, @@ -31,7 +32,7 @@ struct Field { impl_py_gc_traverse!(Field { validator }); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ModelFieldsValidator { fields: Vec, model_name: String, @@ -58,7 +59,7 @@ impl BuildValidator for ModelFieldsValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; - let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { + let extras_validator = match (schema.get_item(intern!(py, "extras_schema"))?, &extra_behavior) { (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)), (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, @@ -81,7 +82,7 @@ impl BuildValidator for ModelFieldsValidator { Err(err) => return py_schema_err!("Field \"{}\":\n {}", field_name, err), }; - let lookup_key = match field_info.get_item(intern!(py, "validation_alias")) { + let lookup_key = match field_info.get_item(intern!(py, "validation_alias"))? { Some(alias) => { let alt_alias = if populate_by_name { Some(field_name) } else { None }; LookupKey::from_py(py, alias, alt_alias)? @@ -211,15 +212,33 @@ impl Validator for ModelFieldsValidator { Err(err) => return ControlFlow::Break(err.into_owned(py)), } continue; - } else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? { - control_flow!(model_dict.set_item(&field.name_py, value))?; - } else { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name - )); + } + + match field.validator.default_value(py, Some(field.name.as_str()), state) { + Ok(Some(value)) => { + // Default value exists, and passed validation if required + control_flow!(model_dict.set_item(&field.name_py, value))?; + }, + Ok(None) => { + // This means there was no default value + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name + )); + }, + Err(ValError::Omit) => continue, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + // Note: this will always use the field name even if there is an alias + // However, we don't mind so much because this error can only happen if the + // default value fails validation, which is arguably a developer error. + // We could try to "fix" this in the future if desired. + errors.push(err); + } + } + Err(err) => return ControlFlow::Break(err), } } ControlFlow::Continue(()) @@ -232,7 +251,7 @@ impl Validator for ModelFieldsValidator { let model_extra_dict = PyDict::new(py); for item_result in <$iter>::new($dict)? { let (raw_key, value) = item_result?; - let either_str = match raw_key.strict_str() { + let either_str = match raw_key.validate_str(true, false).map(ValidationMatch::into_inner) { Ok(k) => k, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -415,27 +434,7 @@ impl Validator for ModelFieldsValidator { Ok((new_data.to_object(py), new_extra, fields_set.to_object(py)).to_object(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.fields - .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|f| f.validator.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), - None => Ok(()), - } - } } diff --git a/src/validators/none.rs b/src/validators/none.rs index 36be70acb..f6891292b 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -36,19 +36,7 @@ impl Validator for NoneValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { - false - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 4b408f206..85fbd6c26 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -9,7 +9,7 @@ use crate::tools::SchemaDict; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct NullableValidator { validator: Box, name: String, @@ -45,19 +45,7 @@ impl Validator for NullableValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) - } } diff --git a/src/validators/set.rs b/src/validators/set.rs index e5e2cecf3..d29c60c3f 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -2,13 +2,14 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PySet}; use crate::errors::ValResult; -use crate::input::Input; +use crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; +use crate::validators::Exactness; use super::list::min_length_check; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SetValidator { strict: bool, item_validator: Box, @@ -25,7 +26,7 @@ macro_rules! set_build { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { + let item_validator = match schema.get_item(pyo3::intern!(schema.py(), "items_schema"))? { Some(d) => Box::new(crate::validators::build_validator(d, config, definitions)?), None => Box::new(crate::validators::any::AnyValidator::build( schema, @@ -64,29 +65,19 @@ impl Validator for SetValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let collection = input.validate_set(state.strict_or(self.strict))?; + let exactness = match &collection { + GenericIterable::Set(_) => Exactness::Exact, + GenericIterable::FrozenSet(_) | GenericIterable::JsonArray(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); let set = PySet::empty(py)?; collection.validate_to_set(py, set, input, self.max_length, "Set", &self.item_validator, state)?; min_length_check!(input, "Set", self.min_length, set); Ok(set.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - self.item_validator.different_strict_behavior(definitions, true) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.item_validator.complete(definitions) - } } diff --git a/src/validators/string.rs b/src/validators/string.rs index 6b646224d..98d8a9d99 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct StrValidator { strict: bool, coerce_numbers_to_str: bool, @@ -47,25 +47,14 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)?; - Ok(either_str.into_py(py)) - } - - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict + input + .validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str) + .map(|val_match| val_match.unpack(state).into_py(py)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } /// Any new properties set here must be reflected in `has_constraints_set` @@ -90,7 +79,9 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)?; + let either_str = input + .validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)? + .unpack(state); let cow = either_str.as_cow()?; let mut str = cow.as_ref(); if self.strip_whitespace { @@ -150,21 +141,9 @@ impl Validator for StrConstrainedValidator { Ok(py_string.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { "constrained-str" } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } impl StrConstrainedValidator { @@ -197,10 +176,10 @@ impl StrConstrainedValidator { let to_upper: bool = schema_or_config(schema, config, intern!(py, "to_upper"), intern!(py, "str_to_upper"))?.unwrap_or(false); - let coerce_numbers_to_str = config - .and_then(|c| c.get_item("coerce_numbers_to_str")) - .and_then(|v| v.is_true().ok()) - .unwrap_or(false); + let coerce_numbers_to_str = match config { + Some(c) => c.get_item("coerce_numbers_to_str")?.map_or(Ok(false), PyAny::is_true)?, + None => false, + }; Ok(Self { strict: is_strict(schema, config)?, diff --git a/src/validators/time.rs b/src/validators/time.rs index 7bbd7e511..abf82091f 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -46,7 +46,9 @@ impl Validator for TimeValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let time = input.validate_time(state.strict_or(self.strict), self.microseconds_precision)?; + let time = input + .validate_time(state.strict_or(self.strict), self.microseconds_precision)? + .unpack(state); if let Some(constraints) = &self.constraints { let raw_time = time.as_raw()?; @@ -78,21 +80,9 @@ impl Validator for TimeValidator { Ok(time.try_into_py(py)?) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } fn convert_pytime(schema: &PyDict, field: &PyString) -> PyResult> { diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index 106d5a64a..f04fef91c 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -25,7 +25,7 @@ struct TimedeltaConstraints { } fn get_constraint(schema: &PyDict, key: &str) -> PyResult> { - match schema.get_item(key) { + match schema.get_item(key)? { Some(value) => { let either_timedelta = EitherTimedelta::try_from(value)?; Ok(Some(either_timedelta.to_duration()?)) @@ -71,7 +71,9 @@ impl Validator for TimeDeltaValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let timedelta = input.validate_timedelta(state.strict_or(self.strict), self.microseconds_precision)?; + let timedelta = input + .validate_timedelta(state.strict_or(self.strict), self.microseconds_precision)? + .unpack(state); let py_timedelta = timedelta.try_into_py(py)?; if let Some(constraints) = &self.constraints { let raw_timedelta = timedelta.to_duration()?; @@ -101,21 +103,9 @@ impl Validator for TimeDeltaValidator { Ok(py_timedelta.into()) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } fn pydelta_to_human_readable(py_delta: &PyDelta) -> String { let total_seconds = py_delta.get_seconds(); diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 5c2c09bec..9513582e5 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -6,11 +6,12 @@ use crate::build_tools::is_strict; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; +use crate::validators::Exactness; use super::list::{get_items_schema, min_length_check}; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TupleVariableValidator { strict: bool, item_validator: Option>, @@ -27,7 +28,7 @@ impl BuildValidator for TupleVariableValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; + let item_validator = get_items_schema(schema, config, definitions)?.map(Box::new); let inner_name = item_validator.as_ref().map_or("any", |v| v.get_name()); let name = format!("tuple[{inner_name}, ...]"); Ok(Self { @@ -51,6 +52,12 @@ impl Validator for TupleVariableValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let seq = input.validate_tuple(state.strict_or(self.strict))?; + let exactness = match &seq { + GenericIterable::Tuple(_) | GenericIterable::JsonArray(_) => Exactness::Exact, + GenericIterable::List(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); let output = match self.item_validator { Some(ref v) => seq.validate_to_vec(py, input, self.max_length, "Tuple", v, state)?, @@ -60,34 +67,12 @@ impl Validator for TupleVariableValidator { Ok(PyTuple::new(py, &output).into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), - None => false, - } - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), - None => Ok(()), - } - } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TuplePositionalValidator { strict: bool, items_validators: Vec, @@ -117,7 +102,7 @@ impl BuildValidator for TuplePositionalValidator { Ok(Self { strict: is_strict(schema, config)?, items_validators: validators, - extras_validator: match schema.get_item(intern!(py, "extras_schema")) { + extras_validator: match schema.get_item(intern!(py, "extras_schema"))? { Some(v) => Some(Box::new(build_validator(v, config, definitions)?)), None => None, }, @@ -203,6 +188,13 @@ impl Validator for TuplePositionalValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let collection = input.validate_tuple(state.strict_or(self.strict))?; + let exactness: crate::validators::Exactness = match &collection { + GenericIterable::Tuple(_) | GenericIterable::JsonArray(_) => Exactness::Exact, + GenericIterable::List(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); + let actual_length = collection.generic_len(); let expected_length = if self.extras_validator.is_some() { actual_length.unwrap_or(self.items_validators.len()) @@ -242,39 +234,7 @@ impl Validator for TuplePositionalValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if ultra_strict { - if self - .items_validators - .iter() - .any(|v| v.different_strict_behavior(definitions, true)) - { - true - } else if let Some(ref v) = self.extras_validator { - v.different_strict_behavior(definitions, true) - } else { - false - } - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.items_validators - .iter_mut() - .try_for_each(|v| v.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), - None => Ok(()), - } - } } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 56e4a8225..f55b7d717 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -8,10 +8,10 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, - MappingGenericIterator, StringMappingGenericIterator, + MappingGenericIterator, StringMappingGenericIterator, ValidationMatch, }; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; @@ -20,7 +20,7 @@ use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; -#[derive(Debug, Clone)] +#[derive(Debug)] struct TypedDictField { name: String, lookup_key: LookupKey, @@ -31,7 +31,7 @@ struct TypedDictField { impl_py_gc_traverse!(TypedDictField { validator }); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TypedDictValidator { fields: Vec, extra_behavior: ExtraBehavior, @@ -61,7 +61,7 @@ impl BuildValidator for TypedDictValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; - let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { + let extras_validator = match (schema.get_item(intern!(py, "extras_schema"))?, &extra_behavior) { (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)), (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, @@ -109,7 +109,7 @@ impl BuildValidator for TypedDictValidator { } } - let lookup_key = match field_info.get_item(intern!(py, "validation_alias")) { + let lookup_key = match field_info.get_item(intern!(py, "validation_alias"))? { Some(alias) => { let alt_alias = if populate_by_name { Some(field_name) } else { None }; LookupKey::from_py(py, alias, alt_alias)? @@ -212,15 +212,35 @@ impl Validator for TypedDictValidator { Err(err) => return ControlFlow::Break(err.into_owned(py)), } continue; - } else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? { - control_flow!(output_dict.set_item(&field.name_py, value))?; - } else if field.required { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name - )); + } + + match field.validator.default_value(py, Some(field.name.as_str()), state) { + Ok(Some(value)) => { + // Default value exists, and passed validation if required + control_flow!(output_dict.set_item(&field.name_py, value))?; + }, + Ok(None) => { + // This means there was no default value + if (field.required) { + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name + )); + } + }, + Err(ValError::Omit) => continue, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + // Note: this will always use the field name even if there is an alias + // However, we don't mind so much because this error can only happen if the + // default value fails validation, which is arguably a developer error. + // We could try to "fix" this in the future if desired. + errors.push(err); + } + } + Err(err) => return ControlFlow::Break(err), } } ControlFlow::Continue(()) @@ -232,7 +252,7 @@ impl Validator for TypedDictValidator { if let Some(ref mut used_keys) = used_keys { for item_result in <$iter>::new($dict)? { let (raw_key, value) = item_result?; - let either_str = match raw_key.strict_str() { + let either_str = match raw_key.validate_str(true, false).map(ValidationMatch::into_inner) { Ok(k) => k, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -307,27 +327,7 @@ impl Validator for TypedDictValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.fields - .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|f| f.validator.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), - None => Ok(()), - } - } } diff --git a/src/validators/union.rs b/src/validators/union.rs index 4d3b0bd78..0f8fded07 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -8,7 +8,7 @@ use smallvec::SmallVec; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config}; -use crate::errors::{ErrorType, LocItem, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ValError, ValLineError, ValResult}; use crate::input::{GenericMapping, Input}; use crate::lookup_key::LookupKey; use crate::py_gc::PyGcTraverse; @@ -16,40 +16,29 @@ use crate::tools::SchemaDict; use super::custom_error::CustomError; use super::literal::LiteralLookup; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator, +}; -#[derive(Debug, Clone, Copy)] +#[derive(Debug)] enum UnionMode { - Smart { - strict_required: bool, - ultra_strict_required: bool, - }, + Smart, LeftToRight, } -impl UnionMode { - // construct smart with some default values - const fn default_smart() -> Self { - Self::Smart { - strict_required: true, - ultra_strict_required: false, - } - } -} - impl FromStr for UnionMode { type Err = PyErr; fn from_str(s: &str) -> Result { match s { - "smart" => Ok(Self::default_smart()), + "smart" => Ok(Self::Smart), "left_to_right" => Ok(Self::LeftToRight), s => py_schema_err!("Invalid union mode: `{}`, expected `smart` or `left_to_right`", s), } } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct UnionValidator { mode: UnionMode, choices: Vec<(CombinedValidator, Option)>, @@ -87,7 +76,7 @@ impl BuildValidator for UnionValidator { let auto_collapse = || schema.get_as_req(intern!(py, "auto_collapse")).unwrap_or(true); let mode = schema .get_as::<&str>(intern!(py, "mode"))? - .map_or(Ok(UnionMode::default_smart()), UnionMode::from_str)?; + .map_or(Ok(UnionMode::Smart), UnionMode::from_str)?; match choices.len() { 0 => py_schema_err!("One or more union choices required"), 1 if auto_collapse() => Ok(choices.into_iter().next().unwrap().0), @@ -112,71 +101,74 @@ impl BuildValidator for UnionValidator { } impl UnionValidator { - fn validate_smart<'s, 'data>( - &'s self, + fn validate_smart<'data>( + &self, py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - strict_required: bool, - ultra_strict_required: bool, ) -> ValResult<'data, PyObject> { - if ultra_strict_required { - // do an ultra strict check first - let state = &mut state.rebind_extra(|extra| { - extra.strict = Some(true); - extra.ultra_strict = true; - }); - if let Some(res) = self - .choices - .iter() - .map(|(validator, _label)| validator.validate(py, input, state)) - .find(ValResult::is_ok) - { - return res; - } - } - + let old_exactness = state.exactness; + let strict = state.strict_or(self.strict); let mut errors = MaybeErrors::new(self.custom_error.as_ref()); - if state.strict_or(self.strict) { - let state = &mut state.rebind_extra(|extra| extra.strict = Some(true)); - for (validator, label) in &self.choices { - match validator.validate(py, input, state) { - Err(ValError::LineErrors(lines)) => errors.push(validator, label.as_deref(), lines), - otherwise => return otherwise, - }; - } + let mut success = None; - Err(errors.into_val_error(input)) - } else { - if strict_required { - // 1st pass: check if the value is an exact instance of one of the Union types, - // e.g. use validate in strict mode - let state = &mut state.rebind_extra(|extra| extra.strict = Some(true)); - if let Some(res) = self - .choices - .iter() - .map(|(validator, _label)| validator.validate(py, input, state)) - .find(ValResult::is_ok) - { - return res; + for (choice, label) in &self.choices { + let state = &mut state.rebind_extra(|extra| { + if strict { + extra.strict = Some(strict); } + }); + state.exactness = Some(Exactness::Exact); + let result = choice.validate(py, input, state); + match result { + Ok(new_success) => match state.exactness { + // exact match, return + Some(Exactness::Exact) => { + return { + // exact match, return, restore any previous exactness + state.exactness = old_exactness; + Ok(new_success) + }; + } + _ => { + // success should always have an exactness + debug_assert_ne!(state.exactness, None); + let new_exactness = state.exactness.unwrap_or(Exactness::Lax); + // if the new result has higher exactness than the current success, replace it + if success + .as_ref() + .map_or(true, |(_, current_exactness)| *current_exactness < new_exactness) + { + // TODO: is there a possible optimization here, where once there has + // been one success, we turn on strict mode, to avoid unnecessary + // coercions for further validation? + success = Some((new_success, new_exactness)); + } + } + }, + Err(ValError::LineErrors(lines)) => { + // if we don't yet know this validation will succeed, record the error + if success.is_none() { + errors.push(choice, label.as_deref(), lines); + } + } + otherwise => return otherwise, } + } + state.exactness = old_exactness; - // 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate - for (validator, label) in &self.choices { - match validator.validate(py, input, state) { - Err(ValError::LineErrors(lines)) => errors.push(validator, label.as_deref(), lines), - otherwise => return otherwise, - }; - } - - Err(errors.into_val_error(input)) + if let Some((success, exactness)) = success { + state.floor_exactness(exactness); + return Ok(success); } + + // no matches, build errors + Err(errors.into_val_error(input)) } - fn validate_left_to_right<'s, 'data>( - &'s self, + fn validate_left_to_right<'data>( + &self, py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, @@ -217,47 +209,14 @@ impl Validator for UnionValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { match self.mode { - UnionMode::Smart { - strict_required, - ultra_strict_required, - } => self.validate_smart(py, input, state, strict_required, ultra_strict_required), + UnionMode::Smart => self.validate_smart(py, input, state), UnionMode::LeftToRight => self.validate_left_to_right(py, input, state), } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.choices - .iter() - .any(|(v, _)| v.different_strict_behavior(definitions, ultra_strict)) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.choices.iter_mut().try_for_each(|(v, _)| v.complete(definitions))?; - if let UnionMode::Smart { - ref mut strict_required, - ref mut ultra_strict_required, - } = self.mode - { - *strict_required = self - .choices - .iter() - .any(|(v, _)| v.different_strict_behavior(Some(definitions), false)); - *ultra_strict_required = self - .choices - .iter() - .any(|(v, _)| v.different_strict_behavior(Some(definitions), true)); - } - - Ok(()) - } } struct ChoiceLineErrors<'a, 'data> { @@ -357,7 +316,7 @@ impl PyGcTraverse for Discriminator { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TaggedUnionValidator { discriminator: Discriminator, lookup: LiteralLookup, @@ -476,27 +435,9 @@ impl Validator for TaggedUnionValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.lookup - .values - .iter() - .any(|v| v.different_strict_behavior(definitions, ultra_strict)) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.lookup - .values - .iter_mut() - .try_for_each(|validator| validator.complete(definitions)) - } } impl TaggedUnionValidator { @@ -507,8 +448,8 @@ impl TaggedUnionValidator { ) -> ValResult<'data, &'data PyString> { let dict = input.strict_dict()?; let either_tag = match dict { - GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "type")) { - Some(t) => t.strict_str()?, + GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "type"))? { + Some(t) => t.validate_str(true, false)?.into_inner(), None => return Err(self.tag_not_found(input)), }, _ => unreachable!(), @@ -518,8 +459,8 @@ impl TaggedUnionValidator { // custom logic to distinguish between different function and tuple schemas if tag == "function" || tag == "tuple" { let mode = match dict { - GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "mode")) { - Some(m) => Some(m.strict_str()?), + GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "mode"))? { + Some(m) => Some(m.validate_str(true, false)?.into_inner()), None => None, }, _ => unreachable!(), @@ -555,7 +496,7 @@ impl TaggedUnionValidator { if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) { return match validator.validate(py, input, state) { Ok(res) => Ok(res), - Err(err) => Err(err.with_outer_location(LocItem::try_from(tag.to_object(py).into_ref(py))?)), + Err(err) => Err(err.with_outer_location(tag.as_loc_item())), }; } match self.custom_error { diff --git a/src/validators/url.rs b/src/validators/url.rs index 0afc76e59..77f887b7d 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -16,6 +16,7 @@ use crate::tools::SchemaDict; use crate::url::{schema_is_special, PyMultiHostUrl, PyUrl}; use super::literal::expected_repr_name; +use super::Exactness; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; type AllowedSchemas = Option<(AHashSet, String)>; @@ -87,32 +88,25 @@ impl Validator for UrlValidator { self.default_port, &self.default_path, ) { - Ok(()) => Ok(PyUrl::new(lib_url).into_py(py)), + Ok(()) => { + // Lax rather than strict to preserve V2.4 semantic that str wins over url in union + state.floor_exactness(Exactness::Lax); + Ok(PyUrl::new(lib_url).into_py(py)) + } Err(error_type) => return Err(ValError::new(error_type, input)), } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } impl UrlValidator { fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, Url> { match input.validate_str(strict, false) { - Ok(either_str) => { + Ok(val_match) => { + let either_str = val_match.into_inner(); let cow = either_str.as_cow()?; let url_str = cow.as_ref(); @@ -227,32 +221,25 @@ impl Validator for MultiHostUrlValidator { self.default_port, &self.default_path, ) { - Ok(()) => Ok(multi_url.into_py(py)), + Ok(()) => { + // Lax rather than strict to preserve V2.4 semantic that str wins over url in union + state.floor_exactness(Exactness::Lax); + Ok(multi_url.into_py(py)) + } Err(error_type) => return Err(ValError::new(error_type, input)), } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } impl MultiHostUrlValidator { fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, PyMultiHostUrl> { match input.validate_str(strict, false) { - Ok(either_str) => { + Ok(val_match) => { + let either_str = val_match.into_inner(); let cow = either_str.as_cow()?; let url_str = cow.as_ref(); @@ -506,7 +493,7 @@ fn check_sub_defaults( if let Some(default_port) = default_port { lib_url .set_port(Some(default_port)) - .map_err(|_| map_parse_err(ParseError::EmptyHost))?; + .map_err(|()| map_parse_err(ParseError::EmptyHost))?; } } if let Some(ref default_path) = default_path { diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index ca924ce66..9e4ce9fb5 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -13,7 +13,7 @@ use crate::tools::SchemaDict; use super::model::create_class; use super::model::force_setattr; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator}; const UUID_INT: &str = "int"; const UUID_IS_SAFE: &str = "is_safe"; @@ -117,26 +117,19 @@ impl Validator for UuidValidator { input, )) } else { + // In python mode this is a coercion, in JSON mode we treat a UUID string as an + // exact match + if input.is_python() { + state.floor_exactness(Exactness::Lax); + } let uuid = self.get_uuid(input)?; self.create_py_uuid(py, class, &uuid) } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { - Ok(()) - } } impl UuidValidator { @@ -158,7 +151,8 @@ impl UuidValidator { None => { let either_bytes = input .validate_bytes(true) - .map_err(|_| ValError::new(ErrorTypeDefaults::UuidType, input))?; + .map_err(|_| ValError::new(ErrorTypeDefaults::UuidType, input))? + .into_inner(); let bytes_slice = either_bytes.as_slice(); 'parse: { // Try parsing as utf8, but don't care if it fails diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index 6cf5ce313..aacd7d2af 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -1,23 +1,26 @@ -use crate::{definitions::Definitions, recursion_guard::RecursionGuard}; +use crate::recursion_guard::RecursionGuard; -use super::{CombinedValidator, Extra}; +use super::Extra; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] +pub enum Exactness { + Lax, + Strict, + Exact, +} pub struct ValidationState<'a> { pub recursion_guard: &'a mut RecursionGuard, - pub definitions: &'a Definitions, + pub exactness: Option, // deliberately make Extra readonly extra: Extra<'a>, } impl<'a> ValidationState<'a> { - pub fn new( - extra: Extra<'a>, - definitions: &'a Definitions, - recursion_guard: &'a mut RecursionGuard, - ) -> Self { + pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionGuard) -> Self { Self { - recursion_guard, - definitions, + recursion_guard, // Don't care about exactness unless doing union validation + exactness: None, extra, } } @@ -31,10 +34,15 @@ impl<'a> ValidationState<'a> { // but lifetimes get in a tangle. Maybe someone brave wants to have a go at unpicking lifetimes. let mut new_state = ValidationState { recursion_guard: self.recursion_guard, - definitions: self.definitions, + exactness: self.exactness, extra, }; - f(&mut new_state) + let result = f(&mut new_state); + match new_state.exactness { + Some(exactness) => self.floor_exactness(exactness), + None => self.exactness = None, + } + result } /// Temporarily rebinds the extra field by calling `f` to modify extra. @@ -57,6 +65,23 @@ impl<'a> ValidationState<'a> { pub fn strict_or(&self, default: bool) -> bool { self.extra.strict.unwrap_or(default) } + + /// Sets the exactness to the lower of the current exactness + /// and the given exactness. + /// + /// This is designed to be used in union validation, where the + /// idea is that the "most exact" validation wins. + pub fn floor_exactness(&mut self, exactness: Exactness) { + match self.exactness { + None | Some(Exactness::Lax) => {} + Some(Exactness::Strict) => { + if exactness == Exactness::Lax { + self.exactness = Some(Exactness::Lax); + } + } + Some(Exactness::Exact) => self.exactness = Some(exactness), + } + } } pub struct ValidationStateWithReboundExtra<'state, 'a> { diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index d68590766..a06ccd0cd 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -66,7 +66,7 @@ enum OnError { Default, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct WithDefaultValidator { default: DefaultType, on_error: OnError, @@ -182,21 +182,9 @@ impl Validator for WithDefaultValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) - } } impl WithDefaultValidator { diff --git a/tests/benchmarks/nested_schema.py b/tests/benchmarks/nested_schema.py new file mode 100644 index 000000000..0d91d1217 --- /dev/null +++ b/tests/benchmarks/nested_schema.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pydantic_core import core_schema as cs + +N = 5 # arbitrary number that takes ~0.05s per run + + +class MyModel: + # __slots__ is not required, but it avoids __pydantic_fields_set__ falling into __dict__ + __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' + + +def schema_using_defs() -> cs.CoreSchema: + definitions: list[cs.CoreSchema] = [ + {'type': 'int', 'ref': 'int'}, + { + 'type': 'model', + 'cls': MyModel, + 'schema': { + 'type': 'model-fields', + 'fields': { + str(c): {'type': 'model-field', 'schema': {'type': 'definition-ref', 'schema_ref': 'int'}} + for c in range(N) + }, + }, + 'ref': f'model_{N}', + }, + ] + level = N + for level in reversed(range(N)): + definitions.append( + { + 'type': 'model', + 'cls': MyModel, + 'schema': { + 'type': 'model-fields', + 'fields': { + str(c): { + 'type': 'model-field', + 'schema': {'type': 'definition-ref', 'schema_ref': f'model_{level+1}'}, + } + for c in range(N) + }, + }, + 'ref': f'model_{level}', + } + ) + return { + 'type': 'definitions', + 'definitions': definitions, + 'schema': {'type': 'definition-ref', 'schema_ref': 'model_0'}, + } + + +def inlined_schema() -> cs.CoreSchema: + level = N + schema: cs.CoreSchema = { + 'type': 'model', + 'cls': MyModel, + 'schema': { + 'type': 'model-fields', + 'fields': {str(c): {'type': 'model-field', 'schema': {'type': 'int'}} for c in range(N)}, + }, + 'ref': f'model_{N}', + } + for level in reversed(range(N)): + schema = { + 'type': 'model', + 'cls': MyModel, + 'schema': { + 'type': 'model-fields', + 'fields': {str(c): {'type': 'model-field', 'schema': schema} for c in range(N)}, + }, + 'ref': f'model_{level}', + } + return schema + + +def input_data_valid(levels: int = N) -> Any: + data = {str(c): 1 for c in range(N)} + for _ in range(levels): + data = {str(c): data for c in range(N)} + return data + + +if __name__ == '__main__': + from pydantic_core import SchemaValidator + + SchemaValidator(schema_using_defs()).validate_python(input_data_valid()) + SchemaValidator(inlined_schema()).validate_python(input_data_valid()) diff --git a/tests/benchmarks/test_nested_benchmark.py b/tests/benchmarks/test_nested_benchmark.py new file mode 100644 index 000000000..6c8d50e83 --- /dev/null +++ b/tests/benchmarks/test_nested_benchmark.py @@ -0,0 +1,23 @@ +""" +Benchmarks for nested / recursive schemas using definitions. +""" + +from typing import Callable + +from pydantic_core import SchemaValidator + +from .nested_schema import inlined_schema, input_data_valid, schema_using_defs + + +def test_nested_schema_using_defs(benchmark: Callable[..., None]) -> None: + v = SchemaValidator(schema_using_defs()) + data = input_data_valid() + v.validate_python(data) + benchmark(v.validate_python, data) + + +def test_nested_schema_inlined(benchmark: Callable[..., None]) -> None: + v = SchemaValidator(inlined_schema()) + data = input_data_valid() + v.validate_python(data) + benchmark(v.validate_python, data) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 48e914dbb..38eacfdbc 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,4 @@ -black==23.9.1 -griffe==0.36.2 -pyright==1.1.327 -ruff==0.0.291 -mypy==1.5.1 +griffe==0.36.9 +pyright==1.1.334 +ruff==0.1.5 +mypy==1.6.1 diff --git a/tests/requirements.txt b/tests/requirements.txt index ae8b9e50d..66dda075f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,14 +3,14 @@ dirty-equals==0.6.0 hypothesis==6.79.4 # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux pandas==2.0.3; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' -pytest==7.4.2 +pytest==7.4.3 # we run codspeed benchmarks on x86_64 CPython (i.e. native github actions architecture) pytest-codspeed~=2.2.0; implementation_name == "cpython" and platform_machine == 'x86_64' pytest-examples==0.0.10 pytest-speed==0.3.5 pytest-mock==3.11.1 pytest-pretty==1.2.0 -pytest-timeout==2.1.0 +pytest-timeout==2.2.0 pytz==2023.3.post1 # numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux numpy==1.25.2; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 448ca108a..98ec22c1f 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -505,6 +505,14 @@ class Foo: assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict() assert s.to_json(Foo(a='hello', b=b'more'), exclude={'a'}) == b'{}' + assert s.to_python(Foo) == Foo + with pytest.raises(PydanticSerializationError, match=r"Unable to serialize unknown type: "): + s.to_python(Foo, mode='json') + with pytest.raises(PydanticSerializationError, match=r"Unable to serialize unknown type: "): + s.to_json(Foo) + assert s.to_python(Foo, mode='json', fallback=lambda x: x.__name__) == 'Foo' + assert s.to_json(Foo, fallback=lambda x: x.__name__) == b'"Foo"' + def test_dataclass_classvar(any_serializer): @dataclasses.dataclass diff --git a/tests/serializers/test_bytes.py b/tests/serializers/test_bytes.py index e313138c9..13849bed0 100644 --- a/tests/serializers/test_bytes.py +++ b/tests/serializers/test_bytes.py @@ -105,6 +105,13 @@ def test_bytes_base64(): assert base64.b64decode(s.to_python(b'foo bar', mode='json').encode()) == b'foo bar' +def test_bytes_hex(): + s = SchemaSerializer(core_schema.bytes_schema(), {'ser_json_bytes': 'hex'}) + assert s.to_python(b'\xff\xff') == b'\xff\xff' + assert s.to_json(b'\xff\xff') == b'"ffff"' + assert s.to_python(b'\xff\xff', mode='json') == 'ffff' == b'\xff\xff'.hex() + + def test_bytes_base64_dict_key(): s = SchemaSerializer(core_schema.dict_schema(core_schema.bytes_schema()), {'ser_json_bytes': 'base64'}) diff --git a/tests/serializers/test_definitions.py b/tests/serializers/test_definitions.py index 2da4d353d..d45398097 100644 --- a/tests/serializers/test_definitions.py +++ b/tests/serializers/test_definitions.py @@ -113,3 +113,24 @@ def test_use_after(): ) ) assert v.to_python((1, 2)) == ('1', '2') + + +def test_defs_with_dict(): + s = SchemaSerializer( + core_schema.definitions_schema( + schema=core_schema.typed_dict_schema( + { + 'foo': core_schema.typed_dict_field( + core_schema.dict_schema( + keys_schema=core_schema.definition_reference_schema('key'), + values_schema=core_schema.definition_reference_schema('val'), + ) + ) + } + ), + definitions=[core_schema.str_schema(ref='key'), core_schema.str_schema(ref='val')], + ) + ) + + assert s.to_json({'foo': {'key': 'val'}}) == b'{"foo":{"key":"val"}}' + assert s.to_python({'foo': {'key': 'val'}}) == {'foo': {'key': 'val'}} diff --git a/tests/serializers/test_functions.py b/tests/serializers/test_functions.py index 318254602..8851a7d36 100644 --- a/tests/serializers/test_functions.py +++ b/tests/serializers/test_functions.py @@ -228,7 +228,7 @@ def append_args(value, info): 'exclude_defaults=False, exclude_none=False, round_trip=False)' ) assert s.to_json(123) == ( - b'"123 info=SerializationInfo(include=None, exclude=None, mode=\'json\', by_alias=True, exclude_unset=False, ' + b"\"123 info=SerializationInfo(include=None, exclude=None, mode='json', by_alias=True, exclude_unset=False, " b'exclude_defaults=False, exclude_none=False, round_trip=False)"' ) diff --git a/tests/serializers/test_pickling.py b/tests/serializers/test_pickling.py new file mode 100644 index 000000000..2ca230313 --- /dev/null +++ b/tests/serializers/test_pickling.py @@ -0,0 +1,50 @@ +import json +import pickle +from datetime import timedelta + +import pytest + +from pydantic_core import core_schema +from pydantic_core._pydantic_core import SchemaSerializer + + +def repr_function(value, _info): + return repr(value) + + +def test_basic_schema_serializer(): + s = SchemaSerializer(core_schema.dict_schema()) + s = pickle.loads(pickle.dumps(s)) + assert s.to_python({'a': 1, b'b': 2, 33: 3}) == {'a': 1, b'b': 2, 33: 3} + assert s.to_python({'a': 1, b'b': 2, 33: 3, True: 4}, mode='json') == {'a': 1, 'b': 2, '33': 3, 'true': 4} + assert s.to_json({'a': 1, b'b': 2, 33: 3, True: 4}) == b'{"a":1,"b":2,"33":3,"true":4}' + + assert s.to_python({(1, 2): 3}) == {(1, 2): 3} + assert s.to_python({(1, 2): 3}, mode='json') == {'1,2': 3} + assert s.to_json({(1, 2): 3}) == b'{"1,2":3}' + + +@pytest.mark.parametrize( + 'value,expected_python,expected_json', + [(None, 'None', b'"None"'), (1, '1', b'"1"'), ([1, 2, 3], '[1, 2, 3]', b'"[1, 2, 3]"')], +) +def test_schema_serializer_capturing_function(value, expected_python, expected_json): + # Test a SchemaSerializer that captures a function. + s = SchemaSerializer( + core_schema.any_schema( + serialization=core_schema.plain_serializer_function_ser_schema(repr_function, info_arg=True) + ) + ) + s = pickle.loads(pickle.dumps(s)) + assert s.to_python(value) == expected_python + assert s.to_json(value) == expected_json + assert s.to_python(value, mode='json') == json.loads(expected_json) + + +def test_schema_serializer_containing_config(): + s = SchemaSerializer(core_schema.timedelta_schema(), config={'ser_json_timedelta': 'float'}) + s = pickle.loads(pickle.dumps(s)) + + assert s.to_python(timedelta(seconds=4, microseconds=500_000)) == timedelta(seconds=4, microseconds=500_000) + assert s.to_python(timedelta(seconds=4, microseconds=500_000), mode='json') == 4.5 + assert s.to_json(timedelta(seconds=4, microseconds=500_000)) == b'4.5' diff --git a/tests/serializers/test_simple.py b/tests/serializers/test_simple.py index 9bbaad05b..b63208c07 100644 --- a/tests/serializers/test_simple.py +++ b/tests/serializers/test_simple.py @@ -136,3 +136,30 @@ def test_numpy(): assert type(v) == float assert s.to_json(numpy.float64(1.0)) == b'1.0' + + +@pytest.mark.parametrize( + 'value,expected_json,config', + [ + # default values of ser_json_inf_nan + (float('inf'), 'null', {}), + (float('-inf'), 'null', {}), + (float('nan'), 'null', {}), + # explicit values of ser_json_inf_nan + (float('inf'), 'null', {'ser_json_inf_nan': 'null'}), + (float('-inf'), 'null', {'ser_json_inf_nan': 'null'}), + (float('nan'), 'null', {'ser_json_inf_nan': 'null'}), + (float('inf'), 'Infinity', {'ser_json_inf_nan': 'constants'}), + (float('-inf'), '-Infinity', {'ser_json_inf_nan': 'constants'}), + (float('nan'), 'NaN', {'ser_json_inf_nan': 'constants'}), + ], +) +def test_float_inf_and_nan_serializers(value, expected_json, config): + s = SchemaSerializer(core_schema.float_schema(), config) + + # Python can represent these values without needing any changes + assert s.to_python(value) is value + assert s.to_python(value, mode='json') is value + + # Serialized JSON value respects the ser_json_inf_nan setting + assert s.to_json(value).decode() == expected_json diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index f81e33a6b..9b021e66e 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -454,3 +454,59 @@ class Item(BaseModel): ) assert s.to_python(DBUser(name='John', password='secret')) == {'name': 'John'} + + +def test_union_serializes_list_of_model_subclass_from_definition() -> None: + class BaseModel: + __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' + + def __init__(self, **kwargs: Any): + for key, value in kwargs.items(): + setattr(self, key, value) + + class User(BaseModel): + name: str + + class DBUser(User): + password: str + __pydantic_serializer__: ClassVar[SchemaSerializer] + + DBUser.__pydantic_serializer__ = SchemaSerializer( + core_schema.model_schema( + DBUser, + core_schema.model_fields_schema( + { + 'name': core_schema.model_field(core_schema.str_schema()), + 'password': core_schema.model_field(core_schema.str_schema()), + } + ), + ) + ) + + class Item(BaseModel): + price: float + + s = SchemaSerializer( + core_schema.definitions_schema( + core_schema.union_schema( + [ + core_schema.list_schema(core_schema.definition_reference_schema('User'), strict=False), + core_schema.list_schema(core_schema.definition_reference_schema('Item'), strict=False), + ] + ), + [ + core_schema.model_schema( + User, + core_schema.model_fields_schema({'name': core_schema.model_field(core_schema.str_schema())}), + ref='User', + ), + core_schema.model_schema( + Item, + core_schema.model_fields_schema({'price': core_schema.model_field(core_schema.float_schema())}), + ref='Item', + ), + ], + ) + ) + + assert s.to_python([DBUser(name='John', password='secret')]) == [{'name': 'John'}] diff --git a/tests/test.rs b/tests/test.rs index 526b30e5e..348520435 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -46,7 +46,7 @@ mod tests { ] }"#; let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap(); - SchemaSerializer::py_new(schema, None).unwrap(); + SchemaSerializer::py_new(py, schema, None).unwrap(); }); } @@ -75,9 +75,9 @@ a = A() "#; let locals = PyDict::new(py); py.run(code, None, Some(locals)).unwrap(); - let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap(); - let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap(); - let serialized: Vec = SchemaSerializer::py_new(schema, None) + let a: &PyAny = locals.get_item("a").unwrap().unwrap().extract().unwrap(); + let schema: &PyDict = locals.get_item("schema").unwrap().unwrap().extract().unwrap(); + let serialized: Vec = SchemaSerializer::py_new(py, schema, None) .unwrap() .to_json(py, a, None, None, None, true, false, false, false, false, true, None) .unwrap() diff --git a/tests/test_errors.py b/tests/test_errors.py index 293880977..05815aec5 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -289,7 +289,9 @@ def f(input_value, info): ('string_unicode', 'Input should be a valid string, unable to parse raw data as a unicode string', None), ('string_pattern_mismatch', "String should match pattern 'foo'", {'pattern': 'foo'}), ('string_too_short', 'String should have at least 42 characters', {'min_length': 42}), + ('string_too_short', 'String should have at least 1 character', {'min_length': 1}), ('string_too_long', 'String should have at most 42 characters', {'max_length': 42}), + ('string_too_long', 'String should have at most 1 character', {'max_length': 1}), ('dict_type', 'Input should be a valid dictionary', None), ('mapping_type', 'Input should be a valid mapping, error: foobar', {'error': 'foobar'}), ('iterable_type', 'Input should be iterable', None), @@ -312,7 +314,9 @@ def f(input_value, info): ('float_parsing', 'Input should be a valid number, unable to parse string as a number', None), ('bytes_type', 'Input should be a valid bytes', None), ('bytes_too_short', 'Data should have at least 42 bytes', {'min_length': 42}), + ('bytes_too_short', 'Data should have at least 1 byte', {'min_length': 1}), ('bytes_too_long', 'Data should have at most 42 bytes', {'max_length': 42}), + ('bytes_too_long', 'Data should have at most 1 byte', {'max_length': 1}), ('value_error', 'Value error, foobar', {'error': ValueError('foobar')}), ('assertion_error', 'Assertion failed, foobar', {'error': AssertionError('foobar')}), ('literal_error', 'Input should be foo', {'expected': 'foo'}), @@ -356,6 +360,7 @@ def f(input_value, info): ('url_parsing', 'Input should be a valid URL, Foobar', {'error': 'Foobar'}), ('url_syntax_violation', 'Input violated strict URL syntax rules, Foobar', {'error': 'Foobar'}), ('url_too_long', 'URL should have at most 42 characters', {'max_length': 42}), + ('url_too_long', 'URL should have at most 1 character', {'max_length': 1}), ('url_scheme', 'URL scheme should be "foo", "bar" or "spam"', {'expected_schemes': '"foo", "bar" or "spam"'}), ('uuid_type', 'UUID input should be a string, bytes or UUID object', None), ('uuid_parsing', 'Input should be a valid UUID, Foobar', {'error': 'Foobar'}), @@ -363,12 +368,19 @@ def f(input_value, info): ('decimal_type', 'Decimal input should be an integer, float, string or Decimal object', None), ('decimal_parsing', 'Input should be a valid decimal', None), ('decimal_max_digits', 'Decimal input should have no more than 42 digits in total', {'max_digits': 42}), + ('decimal_max_digits', 'Decimal input should have no more than 1 digit in total', {'max_digits': 1}), ('decimal_max_places', 'Decimal input should have no more than 42 decimal places', {'decimal_places': 42}), + ('decimal_max_places', 'Decimal input should have no more than 1 decimal place', {'decimal_places': 1}), ( 'decimal_whole_digits', 'Decimal input should have no more than 42 digits before the decimal point', {'whole_digits': 42}, ), + ( + 'decimal_whole_digits', + 'Decimal input should have no more than 1 digit before the decimal point', + {'whole_digits': 1}, + ), ] @@ -777,7 +789,7 @@ def raise_py_error(v: Any) -> Any: with pytest.raises(ValidationError) as exc_info: s.validate_python('anything') - exc = exc_info.value.errors()[0]['ctx']['error'] # type: ignore + exc = exc_info.value.errors()[0]['ctx']['error'] assert isinstance(exc, ValueError) assert isinstance(exc.__context__, AssertionError) @@ -1038,9 +1050,9 @@ def test_loc_with_dots(pydantic_version): ] # insert_assert(str(exc_info.value)) assert str(exc_info.value) == ( - "1 validation error for typed-dict\n" - "`foo.bar`.0\n" - " Input should be a valid integer, unable to parse string as an integer " + '1 validation error for typed-dict\n' + '`foo.bar`.0\n' + ' Input should be a valid integer, unable to parse string as an integer ' "[type=int_parsing, input_value='x', input_type=str]\n" f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing' ) diff --git a/tests/test_garbage_collection.py b/tests/test_garbage_collection.py index d848c91ea..97107e61b 100644 --- a/tests/test_garbage_collection.py +++ b/tests/test_garbage_collection.py @@ -27,7 +27,9 @@ class BaseModel: __schema__: SchemaSerializer def __init_subclass__(cls) -> None: - cls.__schema__ = SchemaSerializer(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER)) + cls.__schema__ = SchemaSerializer( + core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER), config={'ser_json_timedelta': 'float'} + ) cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary() @@ -56,7 +58,10 @@ class BaseModel: __validator__: SchemaValidator def __init_subclass__(cls) -> None: - cls.__validator__ = SchemaValidator(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER)) + cls.__validator__ = SchemaValidator( + core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER), + config=core_schema.CoreConfig(extra_fields_behavior='allow'), + ) cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary() diff --git a/tests/test_json.py b/tests/test_json.py index 9bba05c14..4ef8a1d40 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -4,7 +4,7 @@ from typing import List import pytest -from dirty_equals import IsList +from dirty_equals import IsFloatNan, IsList import pydantic_core from pydantic_core import ( @@ -358,3 +358,10 @@ def test_bad_repr(): to_json(b) assert to_json(b, serialize_unknown=True) == b'""' + + +def test_inf_nan_allow(): + v = SchemaValidator(core_schema.float_schema(allow_inf_nan=True)) + assert v.validate_json('Infinity') == float('inf') + assert v.validate_json('-Infinity') == float('-inf') + assert v.validate_json('NaN') == IsFloatNan() diff --git a/tests/test_typing.py b/tests/test_typing.py index 0d527c619..dcd2f267a 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -23,6 +23,10 @@ def foo(bar: str) -> None: ... +def validator_deprecated(value: Any, info: core_schema.FieldValidationInfo) -> None: + ... + + def validator(value: Any, info: core_schema.ValidationInfo) -> None: ... diff --git a/tests/validators/test_arguments.py b/tests/validators/test_arguments.py index 4ef581b47..4cf1d3ad2 100644 --- a/tests/validators/test_arguments.py +++ b/tests/validators/test_arguments.py @@ -1009,11 +1009,11 @@ def test_error_display(pydantic_version): ] # insert_assert(str(exc_info.value)) assert str(exc_info.value) == ( - "1 validation error for arguments\n" - "b\n" - " Missing required argument [type=missing_argument, " + '1 validation error for arguments\n' + 'b\n' + ' Missing required argument [type=missing_argument, ' "input_value=ArgsKwargs((), {'a': 1}), input_type=ArgsKwargs]\n" - f" For further information visit https://errors.pydantic.dev/{pydantic_version}/v/missing_argument" + f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/missing_argument' ) # insert_assert(exc_info.value.json(include_url=False)) assert exc_info.value.json(include_url=False) == ( diff --git a/tests/validators/test_bool.py b/tests/validators/test_bool.py index e71d41cb1..3ac900701 100644 --- a/tests/validators/test_bool.py +++ b/tests/validators/test_bool.py @@ -63,7 +63,7 @@ def test_bool_error(pydantic_version): '1 validation error for bool\n' ' Input should be a valid boolean, ' "unable to interpret input [type=bool_parsing, input_value='wrong', input_type=str]\n" - f" For further information visit https://errors.pydantic.dev/{pydantic_version}/v/bool_parsing" + f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/bool_parsing' ) assert exc_info.value.errors(include_url=False) == [ { diff --git a/tests/validators/test_date.py b/tests/validators/test_date.py index 5ddde4884..6a552a57b 100644 --- a/tests/validators/test_date.py +++ b/tests/validators/test_date.py @@ -64,6 +64,8 @@ ), id='-inf', ), + pytest.param('-', Err('Input should be a valid date or datetime, input is too short'), id='minus'), + pytest.param('+', Err('Input should be a valid date or datetime, input is too short'), id='pus'), ], ) def test_date(input_value, expected): diff --git a/tests/validators/test_datetime.py b/tests/validators/test_datetime.py index df04d1631..89e9c1c53 100644 --- a/tests/validators/test_datetime.py +++ b/tests/validators/test_datetime.py @@ -1,6 +1,5 @@ import copy import json -import pickle import platform import re from datetime import date, datetime, time, timedelta, timezone, tzinfo @@ -37,6 +36,8 @@ (float('nan'), Err('Input should be a valid datetime, NaN values not permitted [type=datetime_parsing,')), (float('inf'), Err('Input should be a valid datetime, dates after 9999')), (float('-inf'), Err('Input should be a valid datetime, dates before 1600')), + ('-', Err('Input should be a valid datetime, input is too short [type=datetime_parsing,')), + ('+', Err('Input should be a valid datetime, input is too short [type=datetime_parsing,')), ], ) def test_datetime(input_value, expected): @@ -480,17 +481,6 @@ def test_tz_constraint_wrong(): validate_core_schema(core_schema.datetime_schema(tz_constraint='wrong')) -def test_tz_pickle() -> None: - """ - https://github.com/pydantic/pydantic-core/issues/589 - """ - v = SchemaValidator(core_schema.datetime_schema()) - original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15))) - validated = v.validate_python('2022-06-08T12:13:14-12:15') - assert validated == original - assert pickle.loads(pickle.dumps(validated)) == validated == original - - def test_tz_hash() -> None: v = SchemaValidator(core_schema.datetime_schema()) lookup: Dict[datetime, str] = {} diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index 376a9816a..b9fabeaed 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -140,7 +140,8 @@ def test_decimal_strict_json(input_value, expected): {'ge': 0}, -0.1, Err( - 'Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-0.1, input_type=float]' + 'Input should be greater than or equal to 0 ' + '[type=greater_than_equal, input_value=-0.1, input_type=float]' ), ), ({'gt': 0}, 0.1, Decimal('0.1')), @@ -148,6 +149,10 @@ def test_decimal_strict_json(input_value, expected): ({'le': 0}, 0, Decimal(0)), ({'le': 0}, -1, Decimal(-1)), ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), + ({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')), + ({'gt': 0, 'allow_inf_nan': True}, float('inf'), Decimal('inf')), + ({'allow_inf_nan': True}, float('-inf'), Decimal('-inf')), + ({'allow_inf_nan': True}, float('nan'), FunctionCheck(math.isnan)), ({'lt': 0}, 0, Err('Input should be less than 0')), ({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')), ], @@ -433,3 +438,31 @@ def test_non_finite_constrained_decimal_values(input_value, allow_inf_nan, expec def test_validate_scientific_notation_from_json(input_value, expected): v = SchemaValidator({'type': 'decimal'}) assert v.validate_json(input_value) == expected + + +def test_validate_max_digits_and_decimal_places() -> None: + v = SchemaValidator({'type': 'decimal', 'max_digits': 5, 'decimal_places': 2}) + + # valid inputs + assert v.validate_json('1.23') == Decimal('1.23') + assert v.validate_json('123.45') == Decimal('123.45') + assert v.validate_json('-123.45') == Decimal('-123.45') + + # invalid inputs + with pytest.raises(ValidationError): + v.validate_json('1234.56') # too many digits + with pytest.raises(ValidationError): + v.validate_json('123.456') # too many decimal places + with pytest.raises(ValidationError): + v.validate_json('123456') # too many digits + with pytest.raises(ValidationError): + v.validate_json('abc') # not a valid decimal + + +def test_validate_max_digits_and_decimal_places_edge_case() -> None: + v = SchemaValidator({'type': 'decimal', 'max_digits': 34, 'decimal_places': 18}) + + # valid inputs + assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal( + '9999999999999999.999999999999999999' + ) diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index b836eb7a1..9f7a93d57 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -1,3 +1,4 @@ +import datetime import platform from dataclasses import dataclass from typing import List, Optional @@ -243,7 +244,7 @@ class Branch: def test_invalid_schema(): - with pytest.raises(SchemaError, match='Definitions error: attempted to use `Branch` before it was filled'): + with pytest.raises(SchemaError, match='Definitions error: definition `Branch` was never filled'): SchemaValidator( { 'type': 'list', @@ -895,3 +896,211 @@ class Model: 'url': f'https://errors.pydantic.dev/{pydantic_version}/v/dataclass_type', } ] + + +def test_cyclic_data() -> None: + cyclic_data = {} + cyclic_data['b'] = {'a': cyclic_data} + + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('a'), + [ + core_schema.typed_dict_schema( + { + 'b': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('b')) + ) + }, + ref='a', + ), + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('a')) + ) + }, + ref='b', + ), + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python(cyclic_data) + + assert exc_info.value.title == 'typed-dict' + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'recursion_loop', + 'loc': ('b', 'a'), + 'msg': 'Recursion error - cyclic reference detected', + 'input': cyclic_data, + } + ] + + +def test_cyclic_data_threeway() -> None: + cyclic_data = {} + cyclic_data['b'] = {'c': {'a': cyclic_data}} + + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('a'), + [ + core_schema.typed_dict_schema( + { + 'b': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('b')) + ) + }, + ref='a', + ), + core_schema.typed_dict_schema( + { + 'c': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('c')) + ) + }, + ref='b', + ), + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('a')) + ) + }, + ref='c', + ), + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python(cyclic_data) + + assert exc_info.value.title == 'typed-dict' + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'recursion_loop', + 'loc': ('b', 'c', 'a'), + 'msg': 'Recursion error - cyclic reference detected', + 'input': cyclic_data, + } + ] + + +def test_complex_recursive_type() -> None: + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('JsonType'), + [ + core_schema.nullable_schema( + core_schema.union_schema( + [ + core_schema.list_schema(core_schema.definition_reference_schema('JsonType')), + core_schema.dict_schema( + core_schema.str_schema(), core_schema.definition_reference_schema('JsonType') + ), + core_schema.str_schema(), + core_schema.int_schema(), + core_schema.float_schema(), + core_schema.bool_schema(), + ] + ), + ref='JsonType', + ) + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python({'a': datetime.date(year=1992, month=12, day=11)}) + + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'list_type', + 'loc': ('list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]',), + 'msg': 'Input should be a valid list', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'list_type', + 'loc': ('dict[str,...]', 'a', 'list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]'), + 'msg': 'Input should be a valid list', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'dict_type', + 'loc': ('dict[str,...]', 'a', 'dict[str,...]'), + 'msg': 'Input should be a valid dictionary', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'string_type', + 'loc': ('dict[str,...]', 'a', 'str'), + 'msg': 'Input should be a valid string', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'int_type', + 'loc': ('dict[str,...]', 'a', 'int'), + 'msg': 'Input should be a valid integer', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'float_type', + 'loc': ('dict[str,...]', 'a', 'float'), + 'msg': 'Input should be a valid number', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'bool_type', + 'loc': ('dict[str,...]', 'a', 'bool'), + 'msg': 'Input should be a valid boolean', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'string_type', + 'loc': ('str',), + 'msg': 'Input should be a valid string', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'int_type', + 'loc': ('int',), + 'msg': 'Input should be a valid integer', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'float_type', + 'loc': ('float',), + 'msg': 'Input should be a valid number', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'bool_type', + 'loc': ('bool',), + 'msg': 'Input should be a valid boolean', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + ] + + +def test_no_exponential_blowup(): + """See https://github.com/pydantic/pydantic/issues/8049 + + There was a performance bug which led to exponential blowup when trying to + build a schema with many intermingled recursive definitions. + """ + unions = core_schema.union_schema([core_schema.definition_reference_schema(f'foo_{i}') for i in range(100)]) + + schema = core_schema.definitions_schema( + core_schema.typed_dict_schema({'x': core_schema.typed_dict_field(unions)}), + definitions=[ + core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(unions)}, ref=f'foo_{i}') + for i in range(100) + ], + ) + + SchemaValidator(schema) diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index 74f0024ca..4e3bda0c4 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -4,9 +4,9 @@ from typing import Any, Dict import pytest -from dirty_equals import FunctionCheck, IsStr +from dirty_equals import FunctionCheck, IsFloatNan, IsStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson, plain_repr @@ -86,6 +86,8 @@ def test_float_strict(py_and_json: PyAndJson, input_value, expected): ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), ({'lt': 0}, 0, Err('Input should be less than 0')), ({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')), + ({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')), + ({'gt': 0, 'allow_inf_nan': True}, float('inf'), float('inf')), ], ) def test_float_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): @@ -225,6 +227,7 @@ def test_float_key(py_and_json: PyAndJson): ('NaN', True, FunctionCheck(math.isnan)), ('NaN', False, Err("Input should be a finite number [type=finite_number, input_value='NaN', input_type=str]")), ('+inf', True, FunctionCheck(lambda x: math.isinf(x) and x > 0)), + ('inf', True, FunctionCheck(lambda x: math.isinf(x) and x > 0)), ( '+inf', False, @@ -372,3 +375,34 @@ def test_string_with_underscores() -> None: v.validate_python(edge_case) with pytest.raises(ValidationError): v.validate_json(f'"{edge_case}"') + + +def test_allow_inf_nan_true_json() -> None: + v = SchemaValidator(core_schema.float_schema()) + + assert v.validate_json('123') == 123 + assert v.validate_json('NaN') == IsFloatNan() + assert v.validate_json('Infinity') == float('inf') + assert v.validate_json('-Infinity') == float('-inf') + + +def test_allow_inf_nan_false_json() -> None: + v = SchemaValidator(core_schema.float_schema(), core_schema.CoreConfig(allow_inf_nan=False)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError) as exc_info1: + v.validate_json('NaN') + # insert_assert(exc_info.value.errors()) + assert exc_info1.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': IsFloatNan()} + ] + with pytest.raises(ValidationError) as exc_info2: + v.validate_json('Infinity') + assert exc_info2.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': float('inf')} + ] + with pytest.raises(ValidationError) as exc_info3: + v.validate_json('-Infinity') + assert exc_info3.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': float('-inf')} + ] diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 9f94ceb1b..e5ccba1e3 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -289,8 +289,19 @@ def f(input_value, validator, info): v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.int_schema())) - with pytest.raises(TypeError, match='^outer_location must be a str or int$'): - v.validate_python(4) + assert v.validate_python(4) == 6 + + with pytest.raises(ValidationError) as exc_info: + v.validate_python('wrong') + # insert_assert(exc_info.value.errors(include_url=False)) + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'int_parsing', + 'loc': ("('4',)",), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ] def test_function_after(): diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 8d5850dc8..80dd1cf73 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -6,7 +6,7 @@ import pytest from dirty_equals import IsStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson, plain_repr @@ -357,10 +357,10 @@ def test_too_long(pydantic_version): ] # insert_assert(repr(exc_info.value)) assert repr(exc_info.value) == ( - "1 validation error for int\n" - " Unable to parse input string as an integer, exceeded maximum size " + '1 validation error for int\n' + ' Unable to parse input string as an integer, exceeded maximum size ' "[type=int_parsing_size, input_value='111111111111111111111111...11111111111111111111111', input_type=str]\n" - f" For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing_size" + f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing_size' ) @@ -459,3 +459,40 @@ def test_float_subclass() -> None: v_lax = v.validate_python(FloatSubclass(1)) assert v_lax == 1 assert type(v_lax) == int + + +def test_int_subclass_plain_enum() -> None: + v = SchemaValidator({'type': 'int'}) + + from enum import Enum + + class PlainEnum(Enum): + ONE = 1 + + v_lax = v.validate_python(PlainEnum.ONE) + assert v_lax == 1 + assert type(v_lax) == int + + +def test_allow_inf_nan_true_json() -> None: + v = SchemaValidator(core_schema.int_schema(), core_schema.CoreConfig(allow_inf_nan=True)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('NaN') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('Infinity') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('-Infinity') + + +def test_allow_inf_nan_false_json() -> None: + v = SchemaValidator(core_schema.int_schema(), core_schema.CoreConfig(allow_inf_nan=False)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('NaN') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('Infinity') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('-Infinity') diff --git a/tests/validators/test_literal.py b/tests/validators/test_literal.py index e1397aeea..d294f866c 100644 --- a/tests/validators/test_literal.py +++ b/tests/validators/test_literal.py @@ -78,7 +78,7 @@ pytest.param( ['a', 'b'], 'c', - Err("Input should be 'a' or 'b' [type=literal_error, input_value=\'c\', input_type=str]"), + Err("Input should be 'a' or 'b' [type=literal_error, input_value='c', input_type=str]"), id='wrong-multiple-str', ), ([1, '1'], 1, 1), diff --git a/tests/validators/test_pickling.py b/tests/validators/test_pickling.py new file mode 100644 index 000000000..2037ab8c9 --- /dev/null +++ b/tests/validators/test_pickling.py @@ -0,0 +1,53 @@ +import pickle +import re +from datetime import datetime, timedelta, timezone + +import pytest + +from pydantic_core import core_schema, validate_core_schema +from pydantic_core._pydantic_core import SchemaValidator, ValidationError + + +def test_basic_schema_validator(): + v = SchemaValidator( + validate_core_schema( + {'type': 'dict', 'strict': True, 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'int'}} + ) + ) + v = pickle.loads(pickle.dumps(v)) + assert v.validate_python({'1': 2, '3': 4}) == {1: 2, 3: 4} + assert v.validate_python({}) == {} + with pytest.raises(ValidationError, match=re.escape('[type=dict_type, input_value=[], input_type=list]')): + v.validate_python([]) + + +def test_schema_validator_containing_config(): + """ + Verify that the config object is not lost during (de)serialization. + """ + v = SchemaValidator( + core_schema.model_fields_schema({'f': core_schema.model_field(core_schema.str_schema())}), + config=core_schema.CoreConfig(extra_fields_behavior='allow'), + ) + v = pickle.loads(pickle.dumps(v)) + + m, model_extra, fields_set = v.validate_python({'f': 'x', 'extra_field': '123'}) + assert m == {'f': 'x'} + # If the config was lost during (de)serialization, the below checks would fail as + # the default behavior is to ignore extra fields. + assert model_extra == {'extra_field': '123'} + assert fields_set == {'f', 'extra_field'} + + v.validate_assignment(m, 'f', 'y') + assert m == {'f': 'y'} + + +def test_schema_validator_tz_pickle() -> None: + """ + https://github.com/pydantic/pydantic-core/issues/589 + """ + v = SchemaValidator(core_schema.datetime_schema()) + original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15))) + validated = v.validate_python('2022-06-08T12:13:14-12:15') + assert validated == original + assert pickle.loads(pickle.dumps(validated)) == validated == original diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index acb145a58..22bcd5445 100644 --- a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -249,6 +249,21 @@ def test_lax_subclass(FruitEnum, kwargs): assert repr(p) == "'pear'" +@pytest.mark.parametrize('kwargs', [{}, {'to_lower': True}], ids=repr) +def test_lax_subclass_plain_enum(kwargs): + v = SchemaValidator(core_schema.str_schema(**kwargs)) + + from enum import Enum + + class PlainEnum(Enum): + ONE = 'one' + + p = v.validate_python(PlainEnum.ONE) + assert p == 'one' + assert type(p) is str + assert repr(p) == "'one'" + + def test_subclass_preserved() -> None: class StrSubclass(str): pass @@ -277,6 +292,16 @@ def test_coerce_numbers_to_str_disabled_in_strict_mode() -> None: v.validate_json('42') +def test_coerce_numbers_to_str_raises_for_bool() -> None: + config = core_schema.CoreConfig(coerce_numbers_to_str=True) + + v = SchemaValidator(core_schema.str_schema(), config) + with pytest.raises(ValidationError): + v.validate_python(True) + with pytest.raises(ValidationError): + v.validate_json(False) + + @pytest.mark.parametrize( ('number', 'expected_str'), [ @@ -321,9 +346,9 @@ def test_backtracking_regex_rust_unsupported(mode) -> None: SchemaValidator(core_schema.str_schema(pattern=pattern), core_schema.CoreConfig(regex_engine='rust-regex')) assert exc_info.value.args[0] == ( - 'Error building \"str\" validator:\n' + 'Error building "str" validator:\n' ' SchemaError: regex parse error:\n' - ' r(#*)\".*?\"\\1\n' + ' r(#*)".*?"\\1\n' ' ^^\n' 'error: backreferences are not supported' ) diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 5f0729d25..8fb25cff6 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -107,10 +107,10 @@ def test_missing_error(pydantic_version): v.validate_python({'field_a': b'abc'}) # insert_assert(str(exc_info.value)) assert str(exc_info.value) == ( - "1 validation error for typed-dict\n" - "field_b\n" + '1 validation error for typed-dict\n' + 'field_b\n' " Field required [type=missing, input_value={'field_a': b'abc'}, input_type=dict]\n" - f" For further information visit https://errors.pydantic.dev/{pydantic_version}/v/missing" + f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/missing' ) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index ad51fb447..503a5f387 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1,3 +1,9 @@ +from dataclasses import dataclass +from datetime import date, time +from enum import Enum, IntEnum +from typing import Any +from uuid import UUID + import pytest from dirty_equals import IsFloat, IsInt @@ -342,9 +348,6 @@ def test_dirty_behaviour(): def test_int_float(): v = SchemaValidator(core_schema.union_schema([core_schema.int_schema(), core_schema.float_schema()])) - assert 'strict_required:true' in plain_repr(v) - assert 'ultra_strict_required:true' in plain_repr(v) # since "float" schema has ultra-strict behaviour - assert v.validate_python(1) == IsInt(approx=1, delta=0) assert v.validate_json('1') == IsInt(approx=1, delta=0) assert v.validate_python(1.0) == IsFloat(approx=1, delta=0) @@ -382,17 +385,8 @@ def test_str_float(): assert v.validate_json('"1"') == '1' -def test_strict_check(): - v = SchemaValidator(core_schema.union_schema([core_schema.int_schema(), core_schema.json_schema()])) - assert 'strict_required:true' in plain_repr(v) - assert 'ultra_strict_required:false' in plain_repr(v) - - def test_no_strict_check(): v = SchemaValidator(core_schema.union_schema([core_schema.is_instance_schema(int), core_schema.json_schema()])) - assert 'strict_required:false' in plain_repr(v) - assert 'ultra_strict_required:false' in plain_repr(v) - assert v.validate_python(123) == 123 assert v.validate_python('[1, 2, 3]') == [1, 2, 3] @@ -414,8 +408,6 @@ def test_strict_reference(): ], ) ) - assert 'strict_required:true' in plain_repr(v) - assert 'ultra_strict_required:true' in plain_repr(v) # since "float" schema has ultra-strict behaviour assert repr(v.validate_python((1, 2))) == '(1.0, 2)' assert repr(v.validate_python((1.0, (2.0, 3)))) == '(1.0, (2.0, 3))' @@ -501,3 +493,281 @@ def test_left_to_right_union_strict(): out = v.validate_python(1) assert out == 1.0 assert isinstance(out, float) + + +def test_union_function_before_called_once(): + # See https://github.com/pydantic/pydantic/issues/6830 - in particular the + # smart union validator used to call `remove_prefix` twice, which is not + # ideal from a user perspective. + class SpecialValues(str, Enum): + DEFAULT = 'default' + OTHER = 'other' + + special_values_schema = core_schema.no_info_after_validator_function(SpecialValues, core_schema.str_schema()) + + validator_called_count = 0 + + def remove_prefix(v: str): + nonlocal validator_called_count + validator_called_count += 1 + if v.startswith('uuid::'): + return v[6:] + return v + + prefixed_uuid_schema = core_schema.no_info_before_validator_function(remove_prefix, core_schema.uuid_schema()) + + v = SchemaValidator(core_schema.union_schema([special_values_schema, prefixed_uuid_schema])) + + assert v.validate_python('uuid::12345678-1234-5678-1234-567812345678') == UUID( + '12345678-1234-5678-1234-567812345678' + ) + assert validator_called_count == 1 + + +@pytest.mark.parametrize( + ('schema', 'input_value', 'expected_value'), + ( + ( + core_schema.uuid_schema(), + '12345678-1234-5678-1234-567812345678', + UUID('12345678-1234-5678-1234-567812345678'), + ), + (core_schema.date_schema(), '2020-01-01', date(2020, 1, 1)), + (core_schema.time_schema(), '00:00:00', time(0, 0, 0)), + # In V2.4 these already returned strings, so we keep this behaviour in V2 + (core_schema.datetime_schema(), '2020-01-01:00:00:00', '2020-01-01:00:00:00'), + (core_schema.url_schema(), 'https://foo.com', 'https://foo.com'), + (core_schema.multi_host_url_schema(), 'https://bar.com,foo.com', 'https://bar.com,foo.com'), + ), +) +def test_smart_union_json_string_types(schema: core_schema.CoreSchema, input_value: str, expected_value: Any): + # Many types have to be represented in strings as JSON, we make sure that + # when parsing in JSON mode these types are preferred + # TODO: in V3 we will make str win in all these cases. + + validator = SchemaValidator(core_schema.union_schema([schema, core_schema.str_schema()])) + assert validator.validate_json(f'"{input_value}"') == expected_value + # in Python mode the string will be preferred + assert validator.validate_python(input_value) == input_value + + +@pytest.mark.parametrize( + ('schema', 'input_value'), + ( + pytest.param( + core_schema.uuid_schema(), + '12345678-1234-5678-1234-567812345678', + marks=pytest.mark.xfail(reason='TODO: V3'), + ), + (core_schema.date_schema(), '2020-01-01'), + (core_schema.time_schema(), '00:00:00'), + (core_schema.datetime_schema(), '2020-01-01:00:00:00'), + (core_schema.url_schema(), 'https://foo.com'), + (core_schema.multi_host_url_schema(), 'https://bar.com,foo.com'), + ), +) +def test_smart_union_json_string_types_str_first(schema: core_schema.CoreSchema, input_value: str): + # As above, but reversed order; str should always win + validator = SchemaValidator(core_schema.union_schema([core_schema.str_schema(), schema])) + assert validator.validate_json(f'"{input_value}"') == input_value + assert validator.validate_python(input_value) == input_value + + +def test_smart_union_default_fallback(): + """Using a default value does not affect the exactness of the smart union match.""" + + class ModelA: + x: int + y: int = 1 + + class ModelB: + x: int + + schema = core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + { + 'x': core_schema.model_field(core_schema.int_schema()), + 'y': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=1) + ), + } + ), + ), + core_schema.model_schema( + ModelB, core_schema.model_fields_schema({'x': core_schema.model_field(core_schema.int_schema())}) + ), + ] + ) + + validator = SchemaValidator(schema) + + result = validator.validate_python({'x': 1}) + assert isinstance(result, ModelA) + assert result.x == 1 + assert result.y == 1 + + # passing a ModelB explicitly will not match the default value + b = ModelB() + assert validator.validate_python(b) is b + + +def test_smart_union_model_field(): + class ModelA: + x: int + + class ModelB: + x: str + + schema = core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, core_schema.model_fields_schema({'x': core_schema.model_field(core_schema.int_schema())}) + ), + core_schema.model_schema( + ModelB, core_schema.model_fields_schema({'x': core_schema.model_field(core_schema.str_schema())}) + ), + ] + ) + + validator = SchemaValidator(schema) + + result = validator.validate_python({'x': 1}) + assert isinstance(result, ModelA) + assert result.x == 1 + + result = validator.validate_python({'x': '1'}) + assert isinstance(result, ModelB) + assert result.x == '1' + + +def test_smart_union_dataclass_field(): + @dataclass + class ModelA: + x: int + + @dataclass + class ModelB: + x: str + + schema = core_schema.union_schema( + [ + core_schema.dataclass_schema( + ModelA, + core_schema.dataclass_args_schema( + 'ModelA', [core_schema.dataclass_field('x', core_schema.int_schema())] + ), + ['x'], + ), + core_schema.dataclass_schema( + ModelB, + core_schema.dataclass_args_schema( + 'ModelB', [core_schema.dataclass_field('x', core_schema.str_schema())] + ), + ['x'], + ), + ] + ) + + validator = SchemaValidator(schema) + + result = validator.validate_python({'x': 1}) + assert isinstance(result, ModelA) + assert result.x == 1 + + result = validator.validate_python({'x': '1'}) + assert isinstance(result, ModelB) + assert result.x == '1' + + +def test_smart_union_with_any(): + """any is preferred over lax validations""" + + # str not coerced to int + schema = core_schema.union_schema([core_schema.int_schema(), core_schema.any_schema()]) + validator = SchemaValidator(schema) + assert validator.validate_python('1') == '1' + + # int *is* coerced to float, this is a strict validation + schema = core_schema.union_schema([core_schema.float_schema(), core_schema.any_schema()]) + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '1.0' + + +def test_smart_union_validator_function(): + """adding a validator function should not change smart union behaviour""" + + inner_schema = core_schema.union_schema([core_schema.int_schema(), core_schema.float_schema()]) + + validator = SchemaValidator(inner_schema) + assert repr(validator.validate_python(1)) == '1' + assert repr(validator.validate_python(1.0)) == '1.0' + + schema = core_schema.union_schema( + [core_schema.no_info_after_validator_function(lambda v: v * 2, inner_schema), core_schema.str_schema()] + ) + + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '2' + assert repr(validator.validate_python(1.0)) == '2.0' + assert validator.validate_python('1') == '1' + + schema = core_schema.union_schema( + [ + core_schema.no_info_wrap_validator_function(lambda v, handler: handler(v) * 2, inner_schema), + core_schema.str_schema(), + ] + ) + + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '2' + assert repr(validator.validate_python(1.0)) == '2.0' + assert validator.validate_python('1') == '1' + + +def test_smart_union_validator_function_one_arm(): + """adding a validator function should not change smart union behaviour""" + + schema = core_schema.union_schema( + [ + core_schema.float_schema(), + core_schema.no_info_after_validator_function(lambda v: v * 2, core_schema.int_schema()), + ] + ) + + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '2' + assert repr(validator.validate_python(1.0)) == '1.0' + + schema = core_schema.union_schema( + [ + core_schema.float_schema(), + core_schema.no_info_wrap_validator_function(lambda v, handler: handler(v) * 2, core_schema.int_schema()), + ] + ) + + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '2' + assert repr(validator.validate_python(1.0)) == '1.0' + + +def test_int_not_coerced_to_enum(): + class BinaryEnum(IntEnum): + ZERO = 0 + ONE = 1 + + enum_schema = core_schema.lax_or_strict_schema( + core_schema.no_info_after_validator_function(BinaryEnum, core_schema.int_schema()), + core_schema.is_instance_schema(BinaryEnum), + ) + + schema = core_schema.union_schema([enum_schema, core_schema.int_schema()]) + + validator = SchemaValidator(schema) + + assert validator.validate_python(0) is not BinaryEnum.ZERO + assert validator.validate_python(1) is not BinaryEnum.ONE + assert validator.validate_python(BinaryEnum.ZERO) is BinaryEnum.ZERO + assert validator.validate_python(BinaryEnum.ONE) is BinaryEnum.ONE diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 808e4807d..7ca0d9f54 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -654,3 +654,153 @@ def _validator(cls, v, info): gc.collect() assert ref() is None + + +validate_default_raises_examples = [ + ( + {}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {}}, + ], + ), + ( + {'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'x': None}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'x': None}}, + ], + ), + ( + {'x': None, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'y': None}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'y': None}}, + ], + ), + ( + {'y': None, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'x': None, 'y': None}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'x': None, 'y': None}}, + ], + ), + ( + {'x': None, 'y': None, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'x': 1, 'y': None, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': 1}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'x': None, 'y': 1, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': 1}, + ], + ), + ( + {'x': 1, 'y': 1, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': 1}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': 1}, + ], + ), +] + + +@pytest.mark.parametrize( + 'core_schema_constructor,field_constructor', + [ + (core_schema.model_fields_schema, core_schema.model_field), + (core_schema.typed_dict_schema, core_schema.typed_dict_field), + ], +) +@pytest.mark.parametrize('input_value,expected', validate_default_raises_examples) +def test_validate_default_raises( + core_schema_constructor: Union[core_schema.ModelFieldsSchema, core_schema.TypedDictSchema], + field_constructor: Union[core_schema.model_field, core_schema.typed_dict_field], + input_value: dict, + expected: Any, +) -> None: + def _raise(ex: Exception) -> None: + raise ex() + + inner_schema = core_schema.no_info_after_validator_function( + lambda x: _raise(AssertionError), core_schema.nullable_schema(core_schema.int_schema()) + ) + + v = SchemaValidator( + core_schema_constructor( + { + 'x': field_constructor( + core_schema.with_default_schema(inner_schema, default=None, validate_default=True) + ), + 'y': field_constructor( + core_schema.with_default_schema(inner_schema, default=None, validate_default=True) + ), + 'z': field_constructor(core_schema.str_schema()), + } + ) + ) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(input_value) + assert exc_info.value.errors(include_url=False, include_context=False) == expected + + +@pytest.mark.parametrize('input_value,expected', validate_default_raises_examples) +def test_validate_default_raises_dataclass(input_value: dict, expected: Any) -> None: + def _raise(ex: Exception) -> None: + raise ex() + + inner_schema = core_schema.no_info_after_validator_function( + lambda x: _raise(AssertionError), core_schema.nullable_schema(core_schema.int_schema()) + ) + + x = core_schema.dataclass_field( + name='x', schema=core_schema.with_default_schema(inner_schema, default=None, validate_default=True) + ) + y = core_schema.dataclass_field( + name='y', schema=core_schema.with_default_schema(inner_schema, default=None, validate_default=True) + ) + z = core_schema.dataclass_field(name='z', schema=core_schema.str_schema()) + + v = SchemaValidator(core_schema.dataclass_args_schema('XYZ', [x, y, z])) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(input_value) + + assert exc_info.value.errors(include_url=False, include_context=False) == expected