1
+ """
2
+ Regression test for extract_aigrant_companies functionality.
3
+
4
+ This test verifies that data extraction works correctly by extracting
5
+ companies that received AI grants along with their batch numbers,
6
+ based on the TypeScript extract_aigrant_companies evaluation.
7
+ """
8
+
9
+ import os
10
+ import pytest
11
+ import pytest_asyncio
12
+ from pydantic import BaseModel , Field
13
+ from typing import List
14
+
15
+ from stagehand import Stagehand , StagehandConfig
16
+ from stagehand .schemas import ExtractOptions
17
+
18
+
19
class Company(BaseModel):
    # One extracted entry: a company name paired with its grant batch.
    # NOTE(review): described with comments rather than a class docstring,
    # since pydantic may copy a docstring into the schema generated from
    # this model -- confirm before adding one.

    # Company name, as a string.
    company: str = Field(
        ...,
        description="The name of the company",
    )
    # Batch identifier, kept as a string (the tests compare against "4" / "1").
    batch: str = Field(
        ...,
        description="The batch number of the grant",
    )
22
+
23
+
24
class Companies(BaseModel):
    # Top-level extraction schema: wraps the list of Company entries.
    # NOTE(review): described with comments rather than a class docstring,
    # since pydantic may copy a docstring into the schema generated from
    # this model -- confirm before adding one.

    # Required field; no default, so an extraction result must supply it.
    companies: List[Company] = Field(
        ...,
        description="List of companies that received AI grants",
    )
26
+
27
+
28
class TestExtractAigrantCompanies:
    """Regression test for extract_aigrant_companies functionality.

    The same scenario runs in two environments (LOCAL and BROWSERBASE):
    navigate to the AI grant eval site, extract every company together with
    its batch number, then check the total count and the first/last entries.
    The navigation, extraction and assertions live in private helpers so both
    tests exercise exactly the same expectations.
    """

    # Page under test.
    _SITE_URL = "https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/"

    # Prompt handed to the extractor; shared so both environments use the
    # same wording.
    _INSTRUCTION = (
        "Extract all companies that received the AI grant and group them with their "
        "batch numbers as an array of objects. Each object should contain the company "
        "name and its corresponding batch number."
    )

    # Expected extraction results for the page above.
    _EXPECTED_COUNT = 91
    _EXPECTED_FIRST = {"company": "Goodfire", "batch": "4"}
    _EXPECTED_LAST = {"company": "Forefront", "batch": "1"}

    @pytest.fixture(scope="class")
    def local_config(self):
        """Configuration for LOCAL mode testing."""
        return StagehandConfig(
            env="LOCAL",
            model_name="gpt-4o-mini",
            headless=True,
            verbose=1,
            dom_settle_timeout_ms=2000,
            # MODEL_API_KEY takes precedence; OPENAI_API_KEY is the fallback.
            model_client_options={
                "apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")
            },
        )

    @pytest.fixture(scope="class")
    def browserbase_config(self):
        """Configuration for BROWSERBASE mode testing."""
        return StagehandConfig(
            env="BROWSERBASE",
            api_key=os.getenv("BROWSERBASE_API_KEY"),
            project_id=os.getenv("BROWSERBASE_PROJECT_ID"),
            model_name="gpt-4o",
            headless=False,
            verbose=2,
            # MODEL_API_KEY takes precedence; OPENAI_API_KEY is the fallback.
            model_client_options={
                "apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")
            },
        )

    @pytest_asyncio.fixture
    async def local_stagehand(self, local_config):
        """Create a Stagehand instance for LOCAL testing and close it afterwards."""
        stagehand = Stagehand(config=local_config)
        await stagehand.init()
        yield stagehand
        await stagehand.close()

    @pytest_asyncio.fixture
    async def browserbase_stagehand(self, browserbase_config):
        """Create a Stagehand instance for BROWSERBASE testing and close it afterwards.

        Skips the requesting test when either Browserbase credential is
        missing from the environment.
        """
        if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")):
            pytest.skip("Browserbase credentials not available")

        stagehand = Stagehand(config=browserbase_config)
        await stagehand.init()
        yield stagehand
        await stagehand.close()

    async def _extract_companies(self, stagehand):
        """Open the eval page and return the extracted list of companies.

        Args:
            stagehand: An initialised Stagehand instance (from one of the
                fixtures above).

        Returns:
            The ``companies`` attribute of the extraction result.
        """
        await stagehand.page.goto(self._SITE_URL)

        extract_options = ExtractOptions(
            instruction=self._INSTRUCTION,
            schema_definition=Companies,
        )
        result = await stagehand.page.extract(extract_options)

        # Per the original tests: both LOCAL and BROWSERBASE modes return the
        # Pydantic model instance directly, so attribute access is used.
        return result.companies

    @staticmethod
    def _assert_company(position, actual, expected):
        """Assert that ``actual`` matches the expected name and batch.

        Args:
            position: Label used in failure messages ("first" or "last").
            actual: Extracted entry exposing ``company`` and ``batch``.
            expected: Dict with the expected "company" and "batch" values.
        """
        assert actual.company == expected["company"], (
            f"Expected {position} company to be '{expected['company']}', "
            f"but got '{actual.company}'"
        )
        assert actual.batch == expected["batch"], (
            f"Expected {position} company batch to be '{expected['batch']}', "
            f"but got '{actual.batch}'"
        )

    def _assert_expected_companies(self, companies):
        """Check the extracted list against the known contents of the page."""
        # Checked before the exact count so an empty extraction is reported
        # with its own message instead of being masked by the count failure.
        assert len(companies) > 0, "No companies were extracted"

        assert len(companies) == self._EXPECTED_COUNT, (
            f"Expected {self._EXPECTED_COUNT} companies, but got {len(companies)}"
        )

        self._assert_company("first", companies[0], self._EXPECTED_FIRST)
        self._assert_company("last", companies[-1], self._EXPECTED_LAST)

    @pytest.mark.asyncio
    @pytest.mark.regression
    @pytest.mark.local
    async def test_extract_aigrant_companies_local(self, local_stagehand):
        """
        Regression test: extract_aigrant_companies

        Mirrors the TypeScript extract_aigrant_companies evaluation:
        - Navigate to AI grant companies test site
        - Extract all companies that received AI grants with their batch numbers
        - Verify total count is 91
        - Verify first company is "Goodfire" in batch "4"
        - Verify last company is "Forefront" in batch "1"
        """
        companies = await self._extract_companies(local_stagehand)
        self._assert_expected_companies(companies)

    @pytest.mark.asyncio
    @pytest.mark.regression
    @pytest.mark.api
    @pytest.mark.skipif(
        not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")),
        reason="Browserbase credentials not available",
    )
    async def test_extract_aigrant_companies_browserbase(self, browserbase_stagehand):
        """
        Regression test: extract_aigrant_companies (Browserbase)

        Same test as local but running in Browserbase environment.
        """
        companies = await self._extract_companies(browserbase_stagehand)
        self._assert_expected_companies(companies)
0 commit comments