@@ -139,7 +139,8 @@ def defects4j_codegen_output(input_file, output_file, num_output=10):
139139
140140 # 生成多个输出
141141 outputs = []
142- for _ in range (num_output ):
142+ for i in range (num_output ):
143+ print (f"Generating output { i + 1 } /{ num_output } " )
143144 # 构建 prompt
144145 prompt = f"[INST] This is an incorrect code({ filename } ):\n ```java\n { code } \n ```\n You are a software engineer. Can you repair the incorrect code?\n [/INST]\n ```java\n "
145146 print (f"Prompt: { prompt } " , flush = True )
@@ -149,30 +150,36 @@ def defects4j_codegen_output(input_file, output_file, num_output=10):
149150 max_d = 500 - cnt
150151
151152 while True :
152- response = pipe (
153- prompt ,
154- min_length = cnt + 64 ,
155- max_length = cnt + max_d ,
156- temperature = 1.0 ,
157- do_sample = True
158- )[0 ]['generated_text' ]
159-
160- # 提取生成的代码
161153 try :
162- print ('response:' , response )
163- generated_code = response .split ('[/INST]' )[1 ].strip ()
164- if generated_code .startswith ('```java\n ' ):
165- generated_code = generated_code [8 :]
166- if generated_code .endswith ('```' ):
167- generated_code = generated_code [:- 3 ]
168- generated_code = generated_code .strip ()
154+ response = pipe (
155+ prompt ,
156+ min_length = cnt + 64 ,
157+ max_length = cnt + max_d ,
158+ temperature = 1.0 ,
159+ do_sample = True
160+ )[0 ]['generated_text' ]
161+
162+ print ('Full response:' , response )
163+
164+ # 提取生成的代码
165+ if '[/INST]' in response :
166+ generated_code = response .split ('[/INST]' )[1 ].strip ()
167+ if generated_code .startswith ('```java\n ' ):
168+ generated_code = generated_code [8 :]
169+ if generated_code .endswith ('```' ):
170+ generated_code = generated_code [:- 3 ]
171+ generated_code = generated_code .strip ()
172+
173+ if generated_code and not generated_code .isspace ():
174+ print (f"Generated code: { generated_code } " )
175+ outputs .append (generated_code )
176+ break
169177
170- if generated_code != '' :
171- break
172178 except Exception as e :
173- print (f"Error processing { filename } : { e } " )
179+ print (f"Error in generation : { e } " )
174180 pass
175181
182+ print (f"Retrying with increased length..." )
176183 max_d = min (2000 - cnt , max_d + 500 )
177184
178185 outputs .append (generated_code )
0 commit comments